diff --git a/client_example/05_layernorm/layernorm2d.cpp b/client_example/05_layernorm/layernorm2d.cpp
index 3ee7cead7b..7a8e5fec99 100644
--- a/client_example/05_layernorm/layernorm2d.cpp
+++ b/client_example/05_layernorm/layernorm2d.cpp
@@ -12,12 +12,14 @@
 #include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
 
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType = ck::half_t;
-using YDataType = ck::half_t;
-using ComputeDataType = float;
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using SaveMeanInvStdDataType = float;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+#define SAVE_MEAN_INV_STD
 
 constexpr int Rank = 2;
 constexpr int NumReduceDim = 1;
@@ -50,12 +52,16 @@ int main(int argc, char* argv[])
     SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
     SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
     SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
+#ifdef SAVE_MEAN_INV_STD
+    SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * M);
+    SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * M);
+#endif
 
     using DeviceOp = ck::tensor_operation::device::DeviceNormalization;
 
@@ -84,14 +90,21 @@ int main(int argc, char* argv[])
                                 {0, 1},      // gammaStrides
                                 {0, 1},      // betaStrides
                                 {Stride, 1}, // yStrides
+                                {1},         // save_mean Strides
+                                {1},         // save_inv_std Strides
                                 {1},         // reduceDims
                                 1e-4,
                                 x_device_buf.GetDeviceBuffer(),
                                 gamma_device_buf.GetDeviceBuffer(),
                                 beta_device_buf.GetDeviceBuffer(),
                                 y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                save_mean_device_buf.GetDeviceBuffer(),
+                                save_inv_std_device_buf.GetDeviceBuffer(),
+#else
                                 nullptr,
                                 nullptr,
+#endif
                                 PassThrough{});
 
         auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -109,6 +122,10 @@ int main(int argc, char* argv[])
         std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
                                sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
 
+#ifdef SAVE_MEAN_INV_STD
+        num_byte += sizeof(SaveMeanInvStdDataType) * M * 2;
+#endif
+
         float gb_per_sec = num_byte / 1.E6 / ave_time;
 
         std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
@@ -140,17 +157,24 @@ int main(int argc, char* argv[])
         auto argument_ptr = op_ptr->MakeArgumentPointer({M, N},      // lengths
                                 {Stride, 1}, // xStrides
-                                {1},         // gammaStrides
-                                {1},         // betaStrides
+                                {0, 1},      // gammaStrides
+                                {0, 1},      // betaStrides
                                 {Stride, 1}, // yStrides
+                                {1},         // save_mean Strides
+                                {1},         // save_inv_std Strides
                                 {1},         // reduceDims
                                 1e-4,
                                 x_device_buf.GetDeviceBuffer(),
                                 gamma_device_buf.GetDeviceBuffer(),
                                 beta_device_buf.GetDeviceBuffer(),
                                 y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                save_mean_device_buf.GetDeviceBuffer(),
+                                save_inv_std_device_buf.GetDeviceBuffer(),
+#else
                                 nullptr,
                                 nullptr,
+#endif
                                 PassThrough{});
 
         auto invoker_ptr = op_ptr->MakeInvokerPointer();
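A usage sketch (illustration only, not part of the patch): after this change, every call site of MakeArgumentPointer passes two extra stride vectors and two extra output pointers. For the 2D layernorm above, which keeps one mean/inv_std value per row, the full call shape is the following; passing nullptr for both statistics pointers preserves the old forward-only behavior. All names here come from the example above.

    auto argument_ptr = op_ptr->MakeArgumentPointer(
        {M, N},      // lengths
        {Stride, 1}, // xStrides
        {0, 1},      // gammaStrides (broadcast over M)
        {0, 1},      // betaStrides  (broadcast over M)
        {Stride, 1}, // yStrides
        {1},         // saveMeanStrides   (one value per row)
        {1},         // saveInvStdStrides (one value per row)
        {1},         // reduceDims: reduce over N
        1e-4,        // epsilon
        x_device_buf.GetDeviceBuffer(),
        gamma_device_buf.GetDeviceBuffer(),
        beta_device_buf.GetDeviceBuffer(),
        y_device_buf.GetDeviceBuffer(),
        save_mean_device_buf.GetDeviceBuffer(),    // or nullptr to skip
        save_inv_std_device_buf.GetDeviceBuffer(), // or nullptr to skip
        PassThrough{});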
diff --git a/client_example/18_groupnorm/groupnorm_swish.cpp b/client_example/18_groupnorm/groupnorm_swish.cpp
index df0a9ceec6..abe7492c65 100644
--- a/client_example/18_groupnorm/groupnorm_swish.cpp
+++ b/client_example/18_groupnorm/groupnorm_swish.cpp
@@ -12,12 +12,14 @@
 #include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
 
-using XDataType = ck::half_t;
-using GammaDataType = float;
-using BetaDataType = float;
-using YDataType = ck::half_t;
-using ComputeDataType = float;
-using Swish = ck::tensor_operation::element_wise::Swish;
+using XDataType = ck::half_t;
+using GammaDataType = float;
+using BetaDataType = float;
+using YDataType = ck::half_t;
+using SaveMeanInvStdDataType = float;
+using Swish = ck::tensor_operation::element_wise::Swish;
+
+#define SAVE_MEAN_INV_STD
 
 constexpr int Rank = 5;
 constexpr int NumReduceDim = 3;
@@ -49,19 +51,24 @@ int main(int argc, char* argv[])
     std::size_t xy_size = N * H * W * G * C;
     std::size_t gamma_beta_size = G * C;
 
-    std::vector xy_strides = {H * W * G * C, W * G * C, G * C, C, 1};
-    std::vector gamma_beta_strides = {0, 0, 0, C, 1};
+    std::vector xy_strides = {H * W * G * C, W * G * C, G * C, C, 1};
+    std::vector gamma_beta_strides = {0, 0, 0, C, 1};
+    std::vector save_mean_inv_std_strides = {G, 1};
 
     SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size);
     SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size);
     SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size);
     SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
+#ifdef SAVE_MEAN_INV_STD
+    SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
+    SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
+#endif
 
     using DeviceOp = ck::tensor_operation::device::DeviceNormalization;
 
@@ -75,19 +82,26 @@ int main(int argc, char* argv[])
     const auto& generic_op_ptr = op_ptrs[0];
 
     auto generic_argument_ptr =
-        generic_op_ptr->MakeArgumentPointer({N, H, W, G, C},    // lengths
-                                            xy_strides,         // xStrides
-                                            gamma_beta_strides, // gammaStrides
-                                            gamma_beta_strides, // betaStrides
-                                            xy_strides,         // yStrides
-                                            {1, 2, 4},          // reduceDims
+        generic_op_ptr->MakeArgumentPointer({N, H, W, G, C},            // lengths
+                                            xy_strides,                 // xStrides
+                                            gamma_beta_strides,         // gammaStrides
+                                            gamma_beta_strides,         // betaStrides
+                                            xy_strides,                 // yStrides
+                                            save_mean_inv_std_strides,  // save_mean Strides
+                                            save_mean_inv_std_strides,  // save_inv_std Strides
+                                            {1, 2, 4},                  // reduceDims
                                             1e-6,
                                             x_device_buf.GetDeviceBuffer(),
                                             gamma_device_buf.GetDeviceBuffer(),
                                             beta_device_buf.GetDeviceBuffer(),
                                             y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                            save_mean_device_buf.GetDeviceBuffer(),
+                                            save_inv_std_device_buf.GetDeviceBuffer(),
+#else
                                             nullptr,
                                             nullptr,
+#endif
                                             Swish{});
 
     if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
@@ -107,21 +121,29 @@ int main(int argc, char* argv[])
 
     for(int i = 0; i < op_ptrs.size(); ++i)
     {
-        auto& op_ptr = op_ptrs[i];
-        auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C},    // lengths
-                                                        xy_strides,         // xStrides
-                                                        gamma_beta_strides, // gammaStrides
-                                                        gamma_beta_strides, // betaStrides
-                                                        xy_strides,         // yStrides
-                                                        {1, 2, 4},          // reduceDims
-                                                        1e-6,
-                                                        x_device_buf.GetDeviceBuffer(),
-                                                        gamma_device_buf.GetDeviceBuffer(),
-                                                        beta_device_buf.GetDeviceBuffer(),
-                                                        y_device_buf.GetDeviceBuffer(),
-                                                        nullptr,
-                                                        nullptr,
-                                                        Swish{});
+        auto& op_ptr = op_ptrs[i];
+        auto argument_ptr =
+            op_ptr->MakeArgumentPointer({N, H, W, G, C},           // lengths
+                                        xy_strides,                // xStrides
+                                        gamma_beta_strides,        // gammaStrides
+                                        gamma_beta_strides,        // betaStrides
+                                        xy_strides,                // yStrides
+                                        save_mean_inv_std_strides, // save_mean Strides
+                                        save_mean_inv_std_strides, // save_inv_std Strides
+                                        {1, 2, 4},                 // reduceDims
+                                        1e-6,
+                                        x_device_buf.GetDeviceBuffer(),
+                                        gamma_device_buf.GetDeviceBuffer(),
+                                        beta_device_buf.GetDeviceBuffer(),
+                                        y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                        save_mean_device_buf.GetDeviceBuffer(),
+                                        save_inv_std_device_buf.GetDeviceBuffer(),
+#else
+                                        nullptr,
+                                        nullptr,
+#endif
+                                        Swish{});
 
         auto invoker_ptr = op_ptr->MakeInvokerPointer();
 
@@ -139,6 +161,10 @@ int main(int argc, char* argv[])
             sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size +
             sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size;
 
+#ifdef SAVE_MEAN_INV_STD
+        num_byte += sizeof(SaveMeanInvStdDataType) * N * G * 2;
+#endif
+
         float gb_per_sec = num_byte / 1.E6 / ave_time;
 
         std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
@@ -169,20 +195,28 @@ int main(int argc, char* argv[])
     std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl;
 
-    auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C},    // lengths
-                                                    xy_strides,         // xStrides
-                                                    gamma_beta_strides, // gammaStrides
-                                                    gamma_beta_strides, // betaStrides
-                                                    xy_strides,         // yStrides
-                                                    {1, 2, 4},          // reduceDims
-                                                    1e-6,
-                                                    x_device_buf.GetDeviceBuffer(),
-                                                    gamma_device_buf.GetDeviceBuffer(),
-                                                    beta_device_buf.GetDeviceBuffer(),
-                                                    y_device_buf.GetDeviceBuffer(),
-                                                    nullptr,
-                                                    nullptr,
-                                                    Swish{});
+    auto argument_ptr =
+        op_ptr->MakeArgumentPointer({N, H, W, G, C},           // lengths
+                                    xy_strides,                // xStrides
+                                    gamma_beta_strides,        // gammaStrides
+                                    gamma_beta_strides,        // betaStrides
+                                    xy_strides,                // yStrides
+                                    save_mean_inv_std_strides, // save_mean Strides
+                                    save_mean_inv_std_strides, // save_inv_std Strides
+                                    {1, 2, 4},                 // reduceDims
+                                    1e-6,
+                                    x_device_buf.GetDeviceBuffer(),
+                                    gamma_device_buf.GetDeviceBuffer(),
+                                    beta_device_buf.GetDeviceBuffer(),
+                                    y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                    save_mean_device_buf.GetDeviceBuffer(),
+                                    save_inv_std_device_buf.GetDeviceBuffer(),
+#else
+                                    nullptr,
+                                    nullptr,
+#endif
+                                    Swish{});
 
     auto invoker_ptr = op_ptr->MakeInvokerPointer();
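A layout sketch (illustration only): for groupnorm over an NHWGC tensor the reduction runs over (H, W, C), so one mean/inv_std pair survives per (n, g) slot. The statistics tensor is therefore N x G with packed strides {G, 1}, which is exactly what save_mean_inv_std_strides carries above, and the timing loop's traffic estimate grows by the matching 2 * N * G elements.

    // shapes/strides implied by the example above
    std::vector<ck::index_t> save_mean_inv_std_lengths = {N, G};
    std::vector<ck::index_t> save_mean_inv_std_strides = {G, 1};
    std::size_t extra_bytes = sizeof(SaveMeanInvStdDataType) * N * G * 2;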
diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
index fc58ca19f8..6a92e9a2f5 100644
--- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
+++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
@@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor& h_m_n,
                          BetaDataType,
                          HDataType,
                          AccDataType,
+                         AccDataType,
                          HElementOp,
                          2,
                          1>;
 
     Tensor e_m_n(HostTensorDescriptor{M, N});
     Tensor c_m_n(HostTensorDescriptor{M, N});
+    Tensor save_mean({M});
+    Tensor save_inv_std({M});
 
     auto ref_gemm = ReferenceGemm{};
     auto ref_gemm_invoker = ref_gemm.MakeInvoker();
@@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor& h_m_n,
     auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();
     auto ref_layernorm_argument = ref_layernorm.MakeArgument(
-        e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon);
+        e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon);
 
     ref_layernorm_invoker.Run(ref_layernorm_argument);
 }
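A host-side sketch of the semantics the reference op now exposes (assumed, written out plainly for one row; double accumulators and <cmath> are my additions for clarity): the reference layernorm returns the row statistics it computed, and save_inv_std holds 1/sqrt(var + epsilon), not the variance itself.

    for(int m = 0; m < M; ++m)
    {
        double mean = 0, mean_sq = 0;
        for(int n = 0; n < N; ++n)
        {
            double v = static_cast<double>(e_m_n(m, n));
            mean += v;
            mean_sq += v * v;
        }
        mean /= N;
        mean_sq /= N;
        // population variance via E[x^2] - E[x]^2, epsilon guards the sqrt
        double inv_std = 1.0 / std::sqrt(mean_sq - mean * mean + epsilon);
        save_mean(m)    = static_cast<float>(mean);
        save_inv_std(m) = static_cast<float>(inv_std);
    }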
diff --git a/example/27_layernorm/layernorm_fp16.cpp b/example/27_layernorm/layernorm_fp16.cpp
index bb8b954f0a..255452e769 100644
--- a/example/27_layernorm/layernorm_fp16.cpp
+++ b/example/27_layernorm/layernorm_fp16.cpp
@@ -3,12 +3,15 @@
 #include "common.hpp"
 
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType = ck::half_t;
-using YDataType = ck::half_t;
-using ComputeDataType = float;
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using SaveMeanInvStdDataType = float;
+using ComputeDataType = float;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+#define SAVE_MEAN_INV_STD
 
 constexpr int Rank = 2;
 constexpr int NumReduceDim = 1;
@@ -19,6 +22,7 @@ using DeviceInstance =
                        BetaDataType,
                        ComputeDataType,
                        YDataType,
+                       SaveMeanInvStdDataType,
                        PassThrough,
                        Rank,
                        NumReduceDim,
@@ -33,7 +37,8 @@ using DeviceInstance =
                        8, // GammaScalarPerVector
                        1, // BetaVecDim (0=M, 1=K)
                        8, // BetaScalarPerVector
-                       8>; // OutScalarPerVector
+                       8, // YScalarPerVector
+                       1>; // SaveMeanInvStdScalarPerVector
 
 #include "run_layernorm_example.inc"
 
 int main() { return run_groupnorm_example(); }
diff --git a/example/27_layernorm/layernorm_splitk_fp16.cpp b/example/27_layernorm/layernorm_splitk_fp16.cpp
index e0378d028b..e2a85bddc5 100644
--- a/example/27_layernorm/layernorm_splitk_fp16.cpp
+++ b/example/27_layernorm/layernorm_splitk_fp16.cpp
@@ -3,12 +3,15 @@
 #include "common.hpp"
 
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType = ck::half_t;
-using YDataType = ck::half_t;
-using ComputeDataType = float;
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using SaveMeanInvStdDataType = float;
+using ComputeDataType = float;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+#define SAVE_MEAN_INV_STD
 
 constexpr int Rank = 2;
 constexpr int NumReduceDim = 1;
@@ -19,6 +22,7 @@ using DeviceInstance =
                        BetaDataType,
                        ComputeDataType,
                        YDataType,
+                       SaveMeanInvStdDataType,
                        PassThrough,
                        Rank,
                        NumReduceDim,
@@ -33,7 +37,8 @@ using DeviceInstance =
                        8, // GammaScalarPerVector
                        1, // BetaVecDim (0=M, 1=K)
                        8, // BetaScalarPerVector
-                       8>; // YScalarPerVector
+                       8, // YScalarPerVector
+                       1>; // SaveMeanInvStdScalarPerVector
 
 #include "run_layernorm_example.inc"
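A note with a small sketch: the device instances above gain one trailing template argument, SaveMeanInvStdScalarPerVector. The statistics are written along the M dimension while X/Y are vectorized along K in these configurations, so a vector width of 1 is the conservative choice; the gridwise kernels later in this patch statically assert that it divides MThreadSliceSize. A standalone model of that constraint (the _example names are mine, illustrative values only):

    constexpr int MThreadSliceSize_example            = 8;
    constexpr int SaveMeanInvStdDstVectorSize_example = 1;
    static_assert(MThreadSliceSize_example % SaveMeanInvStdDstVectorSize_example == 0,
                  "save_mean/inv_std vector width must divide the per-thread M slice");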
diff --git a/example/27_layernorm/run_layernorm_example.inc b/example/27_layernorm/run_layernorm_example.inc
index 95200b540a..c8f599a39c 100644
--- a/example/27_layernorm/run_layernorm_example.inc
+++ b/example/27_layernorm/run_layernorm_example.inc
@@ -10,22 +10,13 @@ int run_groupnorm_example()
     ck::index_t M = 1024;
     ck::index_t N = 1024;
-    ck::index_t Stride = N;
 
-    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
-        return HostTensorDescriptor({len}, {stride});
-    };
-
-    auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
-        using namespace ck::literals;
-
-        return HostTensorDescriptor({row, col}, {stride, 1_uz});
-    };
-
-    Tensor x(f_host_tensor_descriptor2d(M, N, Stride));
-    Tensor gamma(f_host_tensor_descriptor1d(N, 1));
-    Tensor beta(f_host_tensor_descriptor1d(N, 1));
-    Tensor y(f_host_tensor_descriptor2d(M, N, Stride));
+    Tensor x({M, N});
+    Tensor gamma({N});
+    Tensor beta({N});
+    Tensor y({M, N});
+    Tensor save_mean({M});
+    Tensor save_inv_std({M});
 
     x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
     gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
@@ -35,6 +26,11 @@ int run_groupnorm_example()
     DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
     DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
     DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
+#ifdef SAVE_MEAN_INV_STD
+    DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
+    DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
+                               save_inv_std.mDesc.GetElementSpaceSize());
+#endif
 
     x_dev.ToDevice(x.mData.data());
     gamma_dev.ToDevice(gamma.mData.data());
@@ -47,14 +43,23 @@ int run_groupnorm_example()
         {0, 1},
         {0, 1},
         std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
+        std::vector{save_mean.mDesc.GetStrides().begin(),
+                    save_mean.mDesc.GetStrides().end()},
+        std::vector{save_mean.mDesc.GetStrides().begin(),
+                    save_mean.mDesc.GetStrides().end()},
         {1},
        1e-4,
         x_dev.GetDeviceBuffer(),
         gamma_dev.GetDeviceBuffer(),
         beta_dev.GetDeviceBuffer(),
         y_dev.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+        save_mean_dev.GetDeviceBuffer(),
+        save_inv_std_dev.GetDeviceBuffer(),
+#else
         nullptr,
         nullptr,
+#endif
         PassThrough{});
 
     if(!device_instance.IsSupportedArgument(argument_ptr.get()))
@@ -72,24 +77,45 @@ int run_groupnorm_example()
     bool pass = true;
     {
-        Tensor host_y(f_host_tensor_descriptor2d(M, N, Stride));
-        using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm;
+        Tensor host_y({M, N});
+        Tensor host_save_mean({M});
+        Tensor host_save_inv_std({M});
+
+        using ReferenceInstance =
+            ck::tensor_operation::host::ReferenceLayernorm;
         ReferenceInstance ref;
-        auto ref_argument =
-            ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4);
-        auto ref_invoker = ref.MakeInvoker();
+        auto ref_argument = ref.MakeArgument(x,
+                                             gamma,
+                                             beta,
+                                             host_y,
+                                             host_save_mean,
+                                             host_save_inv_std,
+                                             PassThrough{},
+                                             {M, N},
+                                             {1},
+                                             1e-4);
+        auto ref_invoker = ref.MakeInvoker();
         ref_invoker.Run(ref_argument);
 
         y_dev.FromDevice(y.mData.data());
-        pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3);
+
+#ifdef SAVE_MEAN_INV_STD
+        save_mean_dev.FromDevice(save_mean.mData.data());
+        save_inv_std_dev.FromDevice(save_inv_std.mData.data());
+        pass &= ck::utils::check_err(
+            save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(
+            save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3);
+#endif
     }
 
     return (pass ? 0 : 1);
diff --git a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp
index b36bd761b3..0ede570e62 100644
--- a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp
+++ b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp
@@ -6,11 +6,14 @@
 constexpr int Rank = 5;
 constexpr int NumReduceDim = 3;
 
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType = ck::half_t;
-using YDataType = ck::half_t;
-using ComputeDataType = float;
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using SaveMeanInvStdDataType = float;
+using ComputeDataType = float;
+
+#define SAVE_MEAN_INV_STD
 
 struct YElementOp
 {
@@ -39,6 +42,7 @@ using DeviceInstance =
                        BetaDataType,
                        ComputeDataType,
                        YDataType,
+                       SaveMeanInvStdDataType,
                        YElementOp,
                        Rank,
                        NumReduceDim,
@@ -53,7 +57,8 @@ using DeviceInstance =
                        2, // GammaScalarPerVector
                        1, // BetaVecDim (0=M, 1=K)
                        2, // BetaScalarPerVector
-                       2>; // OutScalarPerVector
+                       2, // YScalarPerVector
+                       1>; // SaveMeanInvStdScalarPerVector
 
 #include "run_groupnorm_example.inc"
diff --git a/example/42_groupnorm/groupnorm_splitk_fp16.cpp b/example/42_groupnorm/groupnorm_splitk_fp16.cpp
index 057b240a63..5f56268e02 100644
--- a/example/42_groupnorm/groupnorm_splitk_fp16.cpp
+++ b/example/42_groupnorm/groupnorm_splitk_fp16.cpp
@@ -6,12 +6,15 @@
 constexpr int Rank = 5;
 constexpr int NumReduceDim = 3;
 
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType = ck::half_t;
-using YDataType = ck::half_t;
-using ComputeDataType = float;
-using YElementOp = ck::tensor_operation::element_wise::Swish;
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using SaveMeanInvStdDataType = float;
+using ComputeDataType = float;
+using YElementOp = ck::tensor_operation::element_wise::Swish;
+
+#define SAVE_MEAN_INV_STD
 
 using DeviceInstance =
     ck::tensor_operation::device::DeviceNormalizationSplitKImpl
-                       2>; // OutScalarPerVector
+                       2, // YScalarPerVector
+                       1>; // SaveMeanInvStdScalarPerVector
 
 #include "run_groupnorm_example.inc"
diff --git a/example/42_groupnorm/groupnorm_swish_fp16.cpp b/example/42_groupnorm/groupnorm_swish_fp16.cpp
index 363f22ed4c..97cd4698aa 100644
--- a/example/42_groupnorm/groupnorm_swish_fp16.cpp
+++ b/example/42_groupnorm/groupnorm_swish_fp16.cpp
@@ -6,12 +6,15 @@
 constexpr int Rank = 5;
 constexpr int NumReduceDim = 3;
 
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType = ck::half_t;
-using YDataType = ck::half_t;
-using ComputeDataType = float;
-using YElementOp = ck::tensor_operation::element_wise::Swish;
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using SaveMeanInvStdDataType = float;
+using ComputeDataType = float;
+using YElementOp = ck::tensor_operation::element_wise::Swish;
+
+#define SAVE_MEAN_INV_STD
 
 using DeviceInstance =
     ck::tensor_operation::device::DeviceNormalizationImpl
-                       2>; // OutScalarPerVector
+                       2, // YScalarPerVector
+                       1>; // SaveMeanInvStdScalarPerVector
 
 #include "run_groupnorm_example.inc"
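A host-side model of the fused element-op (the body of YElementOp was elided in extraction, so this is assumed semantics): sigmoid-mul applies y = v * sigmoid(v) to the normalized value v, which is Swish with beta = 1, matching the Swish element-op the other two examples use.

    #include <cmath>
    // minimal scalar sketch of the fused epilogue
    inline float sigmoid_mul(float v) { return v * (1.0f / (1.0f + std::exp(-v))); }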
diff --git a/example/42_groupnorm/run_groupnorm_example.inc b/example/42_groupnorm/run_groupnorm_example.inc
index 16065c8d46..da41e90639 100644
--- a/example/42_groupnorm/run_groupnorm_example.inc
+++ b/example/42_groupnorm/run_groupnorm_example.inc
@@ -34,6 +34,8 @@ int run_groupnorm_example(int argc, char* argv[])
     Tensor y({N, H, W, G, C});
     Tensor gamma({G, C});
     Tensor beta({G, C});
+    Tensor save_mean({N, G});
+    Tensor save_inv_std({N, G});
 
     ck::utils::FillUniformDistribution{0.f, 1.f}(x);
     ck::utils::FillUniformDistribution{0.f, 1.f}(gamma);
@@ -43,6 +45,11 @@ int run_groupnorm_example(int argc, char* argv[])
     DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
     DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
     DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
+#ifdef SAVE_MEAN_INV_STD
+    DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
+    DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
+                               save_inv_std.mDesc.GetElementSpaceSize());
+#endif
 
     x_dev.ToDevice(x.mData.data());
     gamma_dev.ToDevice(gamma.mData.data());
@@ -57,14 +64,23 @@ int run_groupnorm_example(int argc, char* argv[])
         {0, 0, 0, C, 1},
         {0, 0, 0, C, 1},
         std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
+        std::vector{save_mean.mDesc.GetStrides().begin(),
+                    save_mean.mDesc.GetStrides().end()},
+        std::vector{save_mean.mDesc.GetStrides().begin(),
+                    save_mean.mDesc.GetStrides().end()},
         {1, 2, 4}, // reduction dimension: [H, W, C]
         1e-6,
         x_dev.GetDeviceBuffer(),
         gamma_dev.GetDeviceBuffer(),
         beta_dev.GetDeviceBuffer(),
         y_dev.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+        save_mean_dev.GetDeviceBuffer(),
+        save_inv_std_dev.GetDeviceBuffer(),
+#else
         nullptr,
         nullptr,
+#endif
         y_element_op);
 
     if(!device_instance.IsSupportedArgument(argument_ptr.get()))
@@ -92,21 +108,40 @@ int run_groupnorm_example(int argc, char* argv[])
     bool pass = true;
     {
         Tensor host_y({N, H, W, G, C});
-        using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm;
+        Tensor host_save_mean(HostTensorDescriptor{N, G});
+        Tensor host_save_inv_std(HostTensorDescriptor{N, G});
+        using ReferenceInstance =
+            ck::tensor_operation::host::ReferenceGroupnorm;
         ReferenceInstance ref;
-        auto ref_argument =
-            ref.MakeArgument(x, gamma, beta, host_y, y_element_op, {N, H, W, G, C}, 1e-6);
-        auto ref_invoker = ref.MakeInvoker();
+        auto ref_argument = ref.MakeArgument(x,
+                                             gamma,
+                                             beta,
+                                             host_y,
+                                             host_save_mean,
+                                             host_save_inv_std,
+                                             y_element_op,
+                                             {N, H, W, G, C},
+                                             1e-6);
+        auto ref_invoker = ref.MakeInvoker();
         ref_invoker.Run(ref_argument);
 
         y_dev.FromDevice(y.mData.data());
         pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
+
+#ifdef SAVE_MEAN_INV_STD
+        save_mean_dev.FromDevice(save_mean.mData.data());
+        save_inv_std_dev.FromDevice(save_inv_std.mData.data());
+        pass &= ck::utils::check_err(
+            save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(
+            save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3);
+#endif
     }
 
     return (pass ? 0 : 1);
diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp
index 76361f87a5..c02d540983 100644
--- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp
+++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp
@@ -167,20 +167,31 @@ int main()
         XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{});
 
     Tensor host_y(f_host_tensor_descriptor2d(M, N, Stride));
+    Tensor host_save_mean({M});
+    Tensor host_save_inv_std({M});
 
     using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm;
     ReferenceInstance ref;
-    auto ref_argument =
-        ref.MakeArgument(x, gamma, beta, host_y, YElementwiseOperation{}, {M, N}, {1}, 1e-4);
-    auto ref_invoker = ref.MakeInvoker();
+    auto ref_argument = ref.MakeArgument(x,
+                                         gamma,
+                                         beta,
+                                         host_y,
+                                         host_save_mean,
+                                         host_save_inv_std,
+                                         YElementwiseOperation{},
+                                         {M, N},
+                                         {1},
+                                         1e-4);
+    auto ref_invoker = ref.MakeInvoker();
     ref_invoker.Run(ref_argument);
 
     y_dev.FromDevice(y.mData.data());
diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp
index 1f178f9fcb..97e83ebab2 100644
--- a/include/ck/tensor_operation/gpu/device/device_normalization.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_normalization.hpp
@@ -14,8 +14,8 @@ namespace device {
 
 template
@@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator
                         const std::vector gammaStrides,
                         const std::vector betaStrides,
                         const std::vector yStrides,
+                        const std::vector saveMeanStrides,
+                        const std::vector saveInvStdStrides,
                         const std::vector reduceDims,
                         double epsilon,
                         const void* p_x,
@@ -43,16 +45,16 @@ struct DeviceNormalization : public BaseOperator
 
 template
 using DeviceNormalizationPtr = std::unique_ptr>;
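An interface sketch (illustrative; argument names follow the abstract declaration above, and the template arguments elided by extraction are not reproduced here): every DeviceNormalization::MakeArgumentPointer call site now supplies two extra stride vectors, with rank equal to the number of invariant dimensions, and two extra output pointers, in this order; both pointers may be nullptr to skip the statistics stores.

    auto arg = op->MakeArgumentPointer(lengths,
                                       xStrides,
                                       gammaStrides,
                                       betaStrides,
                                       yStrides,
                                       saveMeanStrides,   // new
                                       saveInvStdStrides, // new
                                       reduceDims,
                                       epsilon,
                                       p_x,
                                       p_gamma,
                                       p_beta,
                                       p_y,
                                       p_save_mean,    // new, may be nullptr
                                       p_save_inv_std, // new, may be nullptr
                                       y_elementwise_op);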
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
index ea0d805043..1ef3350185 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
@@ -28,6 +28,7 @@ template
 struct DeviceNormalizationImpl : public DeviceNormalization
@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization
     static auto MakeSrc2dDescriptor(const std::vector& inLengths,
                                     const std::vector& inStrides,
                                     int numBlockTileIteration)
     {
-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
         static constexpr index_t numSrcDim = Rank;
-        static constexpr bool reduceAllDim = (NumInvariantDim == 0);
 
         const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{});
         const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{});
@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization
 
+    static auto MakeSaveMeanInvStdDescriptor_M(const std::vector& lengths,
+                                               const std::vector& strides)
+    {
+        using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+
+        const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
+        const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
+
+        const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto grid_desc_m =
+            transform_tensor_descriptor(desc,
+                                        make_tuple(make_merge_transform(tupleSrcLengths)),
+                                        make_tuple(InvariantDims{}),
+                                        make_tuple(Sequence<0>{}));
+
+        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
+        const auto pad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+
+        auto grid_desc_m_padded = transform_tensor_descriptor(
+            grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, pad_M)),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0>{}));
+
+        return grid_desc_m_padded;
+    }
+
     using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
+    using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
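A worked example of the right-pad in MakeSaveMeanInvStdDescriptor_M above: the merged invariant length is padded up to the next multiple of M_BlockTileSize so every block sees a full tile, e.g. with an invariant length of 1000 and M_BlockTileSize = 128, pad_M = integer_least_multiple(1000, 128) - 1000 = 1024 - 1000 = 24; the padded rows are masked by the transfer's bounds handling rather than written. The same arithmetic as a tiny standalone function (function name is mine):

    ck::index_t pad_m(ck::index_t invariant_length, ck::index_t m_block_tile_size)
    {
        // identical formula to the descriptor builder above
        return ck::math::integer_least_multiple(invariant_length, m_block_tile_size) -
               invariant_length;
    }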
 
    struct Argument : public BaseArgument
     {
@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization
                  const std::vector gammaStrides,
                  const std::vector betaStrides,
                  const std::vector yStrides,
+                 const std::vector saveMeanStrides,
+                 const std::vector saveInvStdStrides,
                  const std::vector reduceDims,
                  YElementwiseOperation y_elementwise_op,
                  double epsilon,
                  const XDataType* p_x,
                  const GammaDataType* p_gamma,
                  const BetaDataType* p_beta,
-                 YDataType* p_y)
+                 YDataType* p_y,
+                 SaveMeanInvStdDataType* p_saveMean,
+                 SaveMeanInvStdDataType* p_saveInvStd)
             : p_x_(p_x),
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
+              p_saveMean_(p_saveMean),
+              p_saveInvStd_(p_saveInvStd),
               y_elementwise_op_(y_elementwise_op)
         {
             epsilon_ = static_cast(epsilon);
@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization
             yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims);
             gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims);
             betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims);
+            saveMeanStrides_ = saveMeanStrides;
+            saveInvStdStrides_ = saveInvStdStrides;
 
-            long_index_t invariant_length;
-            long_index_t reduce_length;
+            std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_);
 
-            std::tie(invariant_length, reduce_length) =
-                get_2d_lengths(Lengths_);
+            numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
 
-            numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize);
-
-            gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize);
+            gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
 
             x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
             gamma_grid_desc_m_k_ =
@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization{}) <= KThreadClusterSize * KThreadSliceSize;
+
+            if constexpr(NumInvariantDim == 0)
+                invariant_lowest_length_ = 1;
+            else
+                invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
         }
 
         ComputeDataType epsilon_;
@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization
         std::vector Lengths_;
         std::vector xStrides_;
         std::vector gammaStrides_;
         std::vector betaStrides_;
         std::vector yStrides_;
+        std::vector saveMeanStrides_;
+        std::vector saveInvStdStrides_;
 
         YElementwiseOperation y_elementwise_op_;
 
@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization(arg.isSweeponce_);
 
             float avg_time = 0;
@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization(p_arg);
 
-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-
         if constexpr(XYSrcVectorDim == 0)
         {
             if constexpr(NumInvariantDim == 0)
@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalizationinvariant_lowest_length_);
+                if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
                     return false;
 
-                if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
                     return false;
 
-                if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
                     return false;
             };
         }
@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalizationbetaStrides_[NumInvariantDim - 1] != 1)
                 return (false);
 
-            if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
+            if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
                 return (false);
         }
         else // if fastest dim is reduced
@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization
 
+        if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
+            return false;
+
         return true;
     };
 
@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization
                         const std::vector gammaStrides,
                         const std::vector betaStrides,
                         const std::vector yStrides,
+                        const std::vector saveMeanStrides,
+                        const std::vector saveInvStdStrides,
                         const std::vector reduceDims,
                         double epsilon,
                         const void* p_x,
@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization
         return std::make_unique(lengths,
                                 xStrides,
                                 gammaStrides,
                                 betaStrides,
                                 yStrides,
+                                saveMeanStrides,
+                                saveInvStdStrides,
                                 reduceDims,
                                 y_elementwise_op,
                                 epsilon,
                                 static_cast(p_x),
                                 static_cast(p_gamma),
                                 static_cast(p_beta),
-                                static_cast(p_y));
+                                static_cast(p_y),
+                                static_cast(p_saveMean),
+                                static_cast(p_saveInvStd));
     };
 
     std::unique_ptr MakeInvokerPointer() override
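A sketch of the split-K scheme the next file implements: kernel 1 writes per-(row, k-block) partial Welford statistics into a workspace; kernel 2 merges them along k, optionally stores mean/inv_std once per row, then normalizes. A minimal host model of merging two partial statistics (population-variance form of the parallel Welford/Chan combination; struct and function names are mine):

    struct Partial
    {
        float mean, var;
        int count;
    };

    Partial merge(Partial a, Partial b)
    {
        // combine two partial (mean, var, count) triples into one
        int n    = a.count + b.count;
        float d  = b.mean - a.mean;
        float mu = a.mean + d * b.count / n;
        float m2 = a.var * a.count + b.var * b.count + d * d * a.count * b.count / n;
        return {mu, m2 / n, n};
    }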
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
index 8b2b3c41bf..e969e493c6 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
@@ -19,7 +19,7 @@ namespace ck {
 
 template
@@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
                               const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
                               index_t num_k_block_tile_iteration,
                               const XDataType* const __restrict__ p_x_global,
-                              MeanVarDataType* const __restrict__ p_welford_mean,
-                              MeanVarDataType* const __restrict__ p_welford_variance,
+                              WorkspaceMeanVarDataType* const __restrict__ p_welford_mean,
+                              WorkspaceMeanVarDataType* const __restrict__ p_welford_variance,
                               int32_t* const __restrict__ p_welford_count)
 {
     GridwiseWelford::Run(x_grid_desc_m_k,
@@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
 };
 
 template
+          typename XYGammaBetaGridDesc_M_K,
+          typename SaveMeanInvStdGridDesc_M>
 __global__ void
 kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
                               const CountGridDesc_M_KBlock count_grid_desc_m_kblock,
@@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
                               const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K y_grid_desc_m_k,
+                              const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m,
+                              const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m,
                               index_t num_k_mean_var_count_iteration,
                               index_t num_k_block_tile_iteration,
                               index_t k_grid_size,
                               ComputeDataType epsilon,
-                              const MeanVarDataType* const p_mean_global,
-                              const MeanVarDataType* const p_variance_global,
+                              const WorkspaceMeanVarDataType* const p_mean_global,
+                              const WorkspaceMeanVarDataType* const p_variance_global,
                               const int32_t* const p_welford_count_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                              SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                               const YElementwiseOperation y_elementwise_op)
 {
     GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock,
@@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
                                       gamma_grid_desc_m_k,
                                       beta_grid_desc_m_k,
                                       y_grid_desc_m_k,
+                                      save_mean_grid_desc_m,
+                                      save_inv_std_grid_desc_m,
                                       num_k_mean_var_count_iteration,
                                       num_k_block_tile_iteration,
                                       k_grid_size,
@@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
                                       p_gamma_global,
                                       p_beta_global,
                                       p_y_global,
+                                      p_save_mean_global,
+                                      p_save_inv_std_global,
                                       y_elementwise_op);
 };
 } // namespace ck
@@ -107,6 +117,7 @@ template
+          index_t YDstVectorSize,
+          index_t SaveMeanInvStdDstVectorSize>
 struct DeviceNormalizationSplitKImpl : public DeviceNormalization
 {
-    using MeanVarDataType = ComputeDataType;
+    using WorkspaceMeanVarDataType = SaveMeanInvStdDataType;
 
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
 
     static_assert(
@@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization{};
     static constexpr auto I1 = Number<1>{};
 
+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
+
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
 
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
+    static_assert(!reduceAllDim); // TODO
+
     static auto MakeSrc2dDescriptor(const std::vector& inLengths,
                                     const std::vector& inStrides,
                                     int kBlockSize,
                                     int numBlockTileIteration)
     {
-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
         static constexpr index_t numSrcDim = Rank;
-        static constexpr bool reduceAllDim = (NumInvariantDim == 0);
 
         const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{});
        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{});
@@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization
     template
-    static auto MakeMeanVarDescriptor_M_K(index_t M, index_t K)
+    static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K)
     {
         const auto grid_desc_m_k =
             make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
@@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization
     template
-    static auto MakeCountDescriptor_M_K(index_t M, index_t K)
+    static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K)
     {
         const auto grid_desc_m_k =
             make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1));
         return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
     }
+    static auto MakeSaveMeanInvStdDescriptor_M(const std::vector& lengths,
+                                               const std::vector& strides)
+    {
+        using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+
+        const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
+        const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
+
+        const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto grid_desc_m =
+            transform_tensor_descriptor(desc,
+                                        make_tuple(make_merge_transform(tupleSrcLengths)),
+                                        make_tuple(InvariantDims{}),
+                                        make_tuple(Sequence<0>{}));
+
+        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
+        const auto pad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+
+        auto grid_desc_m_padded = transform_tensor_descriptor(
+            grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, pad_M)),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0>{}));
+
+        return grid_desc_m_padded;
+    }
+
     using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
 
     using Kernel1MeanVarGridDesc_M_KBlock =
-        decltype(MakeMeanVarDescriptor_M_K, 1, 1>(1, 1));
+        decltype(MakeWorkspaceMeanVarDescriptor_M_K, 1, 1>(1, 1));
 
     using Kernel2MeanVarGridDesc_M_KBlock =
-        decltype(MakeMeanVarDescriptor_M_K, 1, 1>(1, 1));
+        decltype(MakeWorkspaceMeanVarDescriptor_M_K, 1, 1>(1, 1));
 
     using Kernel2CountGridDesc_M_KBlock =
-        decltype(MakeCountDescriptor_M_K, 1, 1>(1, 1));
+        decltype(MakeWorkspaceCountDescriptor_M_K, 1, 1>(1, 1));
+
+    using SaveMeanInvStdGridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
 
     using GridwiseWelford = GridwiseNormalizationSplitK1st;
 
     using GridwiseWelfordNormalization =
-        GridwiseNormalizationSplitK2nd;
+                                        YDstVectorSize,
+                                        SaveMeanInvStdDstVectorSize>;
 
     struct Argument : public BaseArgument
     {
@@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization
                  const std::vector gammaStrides,
                  const std::vector betaStrides,
                  const std::vector yStrides,
+                 const std::vector saveMeanStrides,
+                 const std::vector saveInvStdStrides,
                  const std::vector reduceDims,
                  YElementwiseOperation y_elementwise_op,
                  double epsilon,
                  const XDataType* p_x,
                  const GammaDataType* p_gamma,
                  const BetaDataType* p_beta,
-                 YDataType* p_y)
+                 YDataType* p_y,
+                 SaveMeanInvStdDataType* p_saveMean,
+                 SaveMeanInvStdDataType* p_saveInvStd)
             : p_x_(p_x),
              p_gamma_(p_gamma),
              p_beta_(p_beta),
              p_y_(p_y),
+             p_saveMean_(p_saveMean),
+             p_saveInvStd_(p_saveInvStd),
              p_workspace_mean_{nullptr},
              p_workspace_var_{nullptr},
              p_workspace_count_{nullptr},
@@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization
             yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims);
             gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims);
             betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims);
+            saveMeanStrides_ = saveMeanStrides;
+            saveInvStdStrides_ = saveInvStdStrides;
 
             std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_);
@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization
             kernel1_mean_var_grid_desc_m_kblock_ =
-                MakeMeanVarDescriptor_M_K, M_BlockTileSize, 1>(MRaw_,
-                                                               kGridSize_);
+                MakeWorkspaceMeanVarDescriptor_M_K, M_BlockTileSize, 1>(
+                    MRaw_, kGridSize_);
 
             kernel2_mean_var_grid_desc_m_kblock_ =
-                MakeMeanVarDescriptor_M_K,
-                                          M_BlockTileSize,
-                                          K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
+                MakeWorkspaceMeanVarDescriptor_M_K,
+                                                   M_BlockTileSize,
+                                                   K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
 
             kernel2_count_grid_desc_m_kblock_ =
-                MakeCountDescriptor_M_K,
-                                        M_BlockTileSize,
-                                        K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
+                MakeWorkspaceCountDescriptor_M_K,
+                                                 M_BlockTileSize,
+                                                 K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
+
+            if constexpr(NumInvariantDim == 0)
+                invariant_lowest_length_ = 1;
+            else
+                invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
         }
 
         ComputeDataType epsilon_;
@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization
         std::vector gammaStrides_;
         std::vector betaStrides_;
         std::vector yStrides_;
+        std::vector saveMeanStrides_;
+        std::vector saveInvStdStrides_;
 
         YElementwiseOperation y_elementwise_op_;
 
@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization;
 
            auto kernel2 = kernel_normalizationSplitK2nd;
+                                                      SrcGridDesc_M_K,
+                                                      SaveMeanInvStdGridDesc_M>;
 
             float avg_time = 0;
-            avg_time += launch_and_time_kernel(stream_config,
-                                               kernel1,
-                                               dim3(arg.gridSize_),
-                                               dim3(BlockSize),
-                                               0,
-                                               arg.x_grid_desc_m_k_,
-                                               arg.kernel1_mean_var_grid_desc_m_kblock_,
-                                               arg.numBlockTileIteration_,
-                                               arg.p_x_,
-                                               static_cast(arg.p_workspace_mean_),
-                                               static_cast(arg.p_workspace_var_),
-                                               static_cast(arg.p_workspace_count_));
+            avg_time += launch_and_time_kernel(
+                stream_config,
+                kernel1,
+                dim3(arg.gridSize_),
+                dim3(BlockSize),
+                0,
+                arg.x_grid_desc_m_k_,
+                arg.kernel1_mean_var_grid_desc_m_kblock_,
+                arg.numBlockTileIteration_,
+                arg.p_x_,
+                static_cast(arg.p_workspace_mean_),
+                static_cast(arg.p_workspace_var_),
+                static_cast(arg.p_workspace_count_));
 
-            avg_time += launch_and_time_kernel(stream_config,
-                                               kernel2,
-                                               dim3(arg.gridSize_),
-                                               dim3(BlockSize),
-                                               0,
-                                               arg.kernel2_mean_var_grid_desc_m_kblock_,
-                                               arg.kernel2_count_grid_desc_m_kblock_,
-                                               arg.x_grid_desc_m_k_,
-                                               arg.gamma_grid_desc_m_k_,
-                                               arg.beta_grid_desc_m_k_,
-                                               arg.y_grid_desc_m_k_,
-                                               arg.numMeanVarCountIteration_,
-                                               arg.numBlockTileIteration_,
-                                               arg.kGridSize_,
-                                               arg.epsilon_,
-                                               static_cast(arg.p_workspace_mean_),
-                                               static_cast(arg.p_workspace_var_),
-                                               static_cast(arg.p_workspace_count_),
-                                               arg.p_x_,
-                                               arg.p_gamma_,
-                                               arg.p_beta_,
-                                               arg.p_y_,
-                                               arg.y_elementwise_op_);
+            avg_time += launch_and_time_kernel(
+                stream_config,
+                kernel2,
+                dim3(arg.gridSize_),
+                dim3(BlockSize),
+                0,
+                arg.kernel2_mean_var_grid_desc_m_kblock_,
+                arg.kernel2_count_grid_desc_m_kblock_,
+                arg.x_grid_desc_m_k_,
+                arg.gamma_grid_desc_m_k_,
+                arg.beta_grid_desc_m_k_,
+                arg.y_grid_desc_m_k_,
+                arg.save_mean_grid_desc_m_,
+                arg.save_inv_std_grid_desc_m_,
+                arg.numMeanVarCountIteration_,
+                arg.numBlockTileIteration_,
+                arg.kGridSize_,
+                arg.epsilon_,
+                static_cast(arg.p_workspace_mean_),
+                static_cast(arg.p_workspace_var_),
+                static_cast(arg.p_workspace_count_),
+                arg.p_x_,
+                arg.p_gamma_,
+                arg.p_beta_,
+                arg.p_y_,
+                arg.p_saveMean_,
+                arg.p_saveInvStd_,
+                arg.y_elementwise_op_);
 
             return avg_time;
         };
@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalizationMRaw_ * pArg_->kGridSize_;
 
         // workspace for welford intermediate mean
-        workspace_size += welford_size * sizeof(MeanVarDataType) + 64;
+        workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;
 
         // workspace for welford intermediate variance
-        workspace_size += welford_size * sizeof(MeanVarDataType) + 64;
+        workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;
 
         // workspace for welford intermediate count
         workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64;
@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalizationp_workspace_mean_ = static_cast(pArg_->p_workspace_);
 
-        index_t mean_space_sz = welford_size * sizeof(MeanVarDataType);
+        index_t mean_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
         mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
 
         // setup buffer used for intermediate welford variance
         pArg_->p_workspace_var_ = reinterpret_cast(pArg_->p_workspace_mean_) + mean_space_sz;
 
-        index_t variance_space_sz = welford_size * sizeof(MeanVarDataType);
+        index_t variance_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
         variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
 
         // setup buffer used for intermediate welford count
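A workspace-layout sketch (mirrors GetWorkSpaceSize/SetWorkSpacePointer above): three sub-buffers, mean and variance sized MRaw * kGridSize elements each and the count sized kGridSize int32 entries, each padded by 64 bytes so the next buffer starts on a 64-byte boundary. The standalone function below is mine, illustrative only.

    #include <cstddef>
    #include <cstdint>

    std::size_t workspace_bytes(std::size_t m_raw, std::size_t k_grid, std::size_t elem_size)
    {
        std::size_t welford_size = m_raw * k_grid;
        return (welford_size * elem_size + 64)    // intermediate welford mean
             + (welford_size * elem_size + 64)    // intermediate welford variance
             + (k_grid * sizeof(int32_t) + 64);   // intermediate welford count
    }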
@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization(p_arg);
 
-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-
         if constexpr(XYVectorDim == 0)
         {
             if constexpr(NumInvariantDim == 0)
@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalizationxStrides_[NumInvariantDim - 1] != 1)
                     return false;
 
-                if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
                     return false;
 
-                if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
                     return false;
             };
         }
@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalizationbetaStrides_[NumInvariantDim - 1] != 1)
                 return false;
 
-            if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
+            if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
                 return false;
         }
         else // if fastest dim is reduced
@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalizationkGridSize_ <= 1)
             return false;
 
+        if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
+            return false;
+
         return true;
     };
 
@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization
                         const std::vector gammaStrides,
                         const std::vector betaStrides,
                         const std::vector yStrides,
+                        const std::vector saveMeanStrides,
+                        const std::vector saveInvStdStrides,
                         const std::vector reduceDims,
                         double epsilon,
                         const void* p_x,
@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization
         return std::make_unique(lengths,
                                 xStrides,
                                 gammaStrides,
                                 betaStrides,
                                 yStrides,
+                                saveMeanStrides,
+                                saveInvStdStrides,
                                 reduceDims,
                                 y_elementwise_op,
                                 epsilon,
                                 static_cast(p_x),
                                 static_cast(p_gamma),
                                 static_cast(p_beta),
-                                static_cast(p_y));
+                                static_cast(p_y),
+                                static_cast(p_saveMean),
+                                static_cast(p_saveInvStd));
     };
 
     std::unique_ptr MakeInvokerPointer() override
diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp
index c3f122106d..1bee1b93f9 100644
--- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp
+++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp
@@ -18,9 +18,11 @@ template
 struct GridwiseNormalizationNaiveVariance_mk_to_mk
 {
@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                       (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");
 
+    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+                  "configuration, please check!");
+
     static_assert(XSrcVectorSize == YDstVectorSize);
     static_assert(XSrcVectorSize == GammaSrcVectorSize);
     static_assert(XSrcVectorSize == BetaSrcVectorSize);
@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
     static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
         make_tuple(Number{}, Number{}));
 
+    using ThreadBufferLengths_M = Sequence;
+    static constexpr auto thread_buffer_desc_m =
+        make_naive_tensor_descriptor_packed(make_tuple(Number{}));
+
     using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
         make_tuple(Number{}, Number{})));
     using ThreadReduceDstDesc_M =
@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                    reduce::Add,
                                    true>;
 
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                const GridDesc_M_K& gamma_grid_desc_m_k,
                                const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
+                               const GridDesc_M& save_mean_grid_desc_m,
+                               const GridDesc_M& save_inv_std_grid_desc_m,
                                index_t num_k_block_tile_iteration,
                                ComputeDataType epsilon,
                                const XDataType* const __restrict__ p_x_global,
                                const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                                YDataType* const __restrict__ p_y_global,
+                               SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                               SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                                const YElementwiseOperation y_elementwise_op)
     {
         // LDS
@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
         auto y_global_val_buf = make_dynamic_buffer(
             p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
 
+        auto save_mean_global_val_buf = make_dynamic_buffer(
+            p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
+
+        auto save_inv_std_global_val_buf = make_dynamic_buffer(
+            p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
+
         auto x_thread_buf = generate_tuple(
             [&](auto) {
                 return StaticBuffer& var_thread_buf = mean_square_thread_buf;
+        StaticBuffer&
+            inv_std_thread_buf = mean_square_thread_buf;
 
         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                    thread_k_cluster_id * YDstVectorSize),
             y_elementwise_op);
 
+        auto threadwise_mean_store =
+            ThreadwiseTensorSliceTransfer_v1r3,            // DimAccessOrder
+                                               0,          // SrcVectorDim
+                                               SaveMeanInvStdDstVectorSize, // ScalarPerVector
+                                               InMemoryDataOperationEnum::Set,
+                                               1,
+                                               true>(
+                save_mean_grid_desc_m,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                 thread_m_cluster_id * MThreadSliceSize),
+                PassThroughOp{});
+
+        auto threadwise_inv_std_store =
+            ThreadwiseTensorSliceTransfer_v1r3,            // DimAccessOrder
+                                               0,          // SrcVectorDim
+                                               SaveMeanInvStdDstVectorSize, // ScalarPerVector
+                                               InMemoryDataOperationEnum::Set,
+                                               1,
+                                               true>(
+                save_inv_std_grid_desc_m,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                 thread_m_cluster_id * MThreadSliceSize),
+                PassThroughOp{});
         constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
         constexpr auto thread_copy_bwd_step_m_k =
             make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
 
@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
         // E(x), E[x^2], var(x)
         // FIXME: Should not hack the transform from deviceOP
-        int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
+        ComputeDataType reduce_length = type_convert(
+            x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
             mean_thread_buf(I) = reduce::Add::template GetIdentityValue();
@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                 // var(x) = E[x^2] - E[x]^2
                 var_thread_buf(I) =
                     mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+
+                inv_std_thread_buf(I) = type_convert(1.0f) /
+                                        ck::math::sqrt(var_thread_buf(I) + epsilon);
             });
 
+            // save mean and inverse std for backward (optional)
+            if(thread_k_cluster_id == 0)
+            {
+                if(p_save_mean_global != nullptr)
+                {
+                    threadwise_mean_store.Run(thread_buffer_desc_m,
+                                              make_tuple(I0),
+                                              mean_thread_buf,
+                                              save_mean_grid_desc_m,
+                                              save_mean_global_val_buf);
+                }
+                if(p_save_inv_std_global != nullptr)
+                {
+                    threadwise_inv_std_store.Run(thread_buffer_desc_m,
+                                                 make_tuple(I0),
+                                                 inv_std_thread_buf,
+                                                 save_inv_std_grid_desc_m,
+                                                 save_inv_std_global_val_buf);
+                }
+            }
+
+            // normalization
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
                 static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                     static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                         constexpr auto offset_m_k =
 
                         // normalize
                         y_thread_buf(iK0)(Number{}) =
                             (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) *
-                            divisor;
+                            inv_std_thread_buf(iM);
 
                         // gamma & beta
                         y_thread_buf(iK0)(Number{}) =
@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                 // var(x) = E[x^2] - E[x]^2
                 var_thread_buf(I) =
                     mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+
+                inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
             });
 
+            if(thread_k_cluster_id == 0)
+            {
+                if(p_save_mean_global != nullptr)
+                {
+                    threadwise_mean_store.Run(thread_buffer_desc_m,
+                                              make_tuple(I0),
+                                              mean_thread_buf,
+                                              save_mean_grid_desc_m,
+                                              save_mean_global_val_buf);
+                }
+                if(p_save_inv_std_global != nullptr)
+                {
+                    threadwise_inv_std_store.Run(thread_buffer_desc_m,
+                                                 make_tuple(I0),
+                                                 inv_std_thread_buf,
+                                                 save_inv_std_grid_desc_m,
+                                                 save_inv_std_global_val_buf);
+                }
+            }
+
             auto thread_copy_tail_m_k =
                 (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
 
@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             });
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
                 static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                     static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                         constexpr auto offset_m_k =
@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                         // normalize
                         y_thread_buf(iK0)(Number{}) =
                             (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) *
-                            divisor;
+                            inv_std_thread_buf(iM);
 
                         // gamma
                         y_thread_buf(iK0)(Number{}) =
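A numerics note with a tiny sketch: the naive-variance path accumulates E[x] and E[x^2] in one sweep and uses var(x) = E[x^2] - E[x]^2; the patch folds the epsilon and square root into inv_std once per row, stores it when a destination pointer is given, and reuses it for normalization instead of recomputing a per-row divisor. A scalar model (function name is mine):

    #include <cmath>

    float inv_std_from_moments(float mean, float mean_sq, float eps)
    {
        // E[x^2] - E[x]^2 can go slightly negative in floating point; eps guards the sqrt
        float var = mean_sq - mean * mean;
        return 1.0f / std::sqrt(var + eps);
    }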
diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp
index e50fb98133..8157b4fbc3 100644
--- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp
+++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp
@@ -12,31 +12,42 @@ template
-__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
-                                     const GridDesc_M_K gamma_grid_desc_m_k,
-                                     const GridDesc_M_K beta_grid_desc_m_k,
-                                     const GridDesc_M_K y_grid_desc_m_k,
-                                     index_t num_k_block_tile_iteration,
-                                     ComputeDataType epsilon,
-                                     const XDataType* const __restrict__ p_x_global,
-                                     const GammaDataType* const __restrict__ p_gamma_global,
-                                     const BetaDataType* const __restrict__ p_beta_global,
-                                     YDataType* const __restrict__ p_y_global,
-                                     const YElementwiseOperation y_elementwise_op)
+          typename GridDesc_M_K,
+          typename GridDesc_M>
+__global__ void
+kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
+                     const GridDesc_M_K gamma_grid_desc_m_k,
+                     const GridDesc_M_K beta_grid_desc_m_k,
+                     const GridDesc_M_K y_grid_desc_m_k,
+                     const GridDesc_M save_mean_grid_desc_m,
+                     const GridDesc_M save_inv_std_grid_desc_m,
+                     index_t num_k_block_tile_iteration,
+                     ComputeDataType epsilon,
+                     const XDataType* const __restrict__ p_x_global,
+                     const GammaDataType* const __restrict__ p_gamma_global,
+                     const BetaDataType* const __restrict__ p_beta_global,
+                     YDataType* const __restrict__ p_y_global,
+                     SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                     SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
+                     const YElementwiseOperation y_elementwise_op)
 {
     GridwiseReduction::Run(x_grid_desc_m_k,
                            gamma_grid_desc_m_k,
                            beta_grid_desc_m_k,
                            y_grid_desc_m_k,
+                           save_mean_grid_desc_m,
+                           save_inv_std_grid_desc_m,
                            num_k_block_tile_iteration,
                            epsilon,
                            p_x_global,
                           p_gamma_global,
                            p_beta_global,
                            p_y_global,
+                           p_save_mean_global,
+                           p_save_inv_std_global,
                            y_elementwise_op);
 };
 
@@ -44,9 +55,11 @@ template
 auto NormalizationKernelSelector(bool isSweepOnce)
 {
@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                  GammaDataType,
                                                  BetaDataType,
                                                  YDataType,
+                                                 SaveMeanInvStdDataType,
                                                  ComputeDataType,
                                                  YElementwiseOperation,
                                                  GridDesc_M_K,
+                                                 GridDesc_M,
                                                  BlockSize,
                                                  MThreadClusterSize,
                                                  KThreadClusterSize,
@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                                  BetaSrcVectorSize,
                                                  YDstVectorDim,
                                                  YDstVectorSize,
+                                                 SaveMeanInvStdDstVectorSize,
                                                  false>;
 
     using GridwiseNormalizationSweepOnceNaive =
         GridwiseNormalizationNaiveVariance_mk_to_mk;
     using GridwiseNormalizationGenericWelford =
        GridwiseNormalizationWelfordVariance_mk_to_mk;
     using GridwiseNormalizationSweepOnceWelford =
         GridwiseNormalizationWelfordVariance_mk_to_mk;
 
     if constexpr(UseWelford)
@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                         GammaDataType,
                                         BetaDataType,
                                         YDataType,
+                                        SaveMeanInvStdDataType,
                                         ComputeDataType,
                                         YElementwiseOperation,
-                                        GridDesc_M_K>
+                                        GridDesc_M_K,
+                                        GridDesc_M>
                   : kernel_normalization;
+                                        GridDesc_M_K,
+                                        GridDesc_M>;
     }
     else
     {
@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
                                         GammaDataType,
                                         BetaDataType,
                                         YDataType,
+                                        SaveMeanInvStdDataType,
                                         ComputeDataType,
                                         YElementwiseOperation,
-                                        GridDesc_M_K>
+                                        GridDesc_M_K,
+                                        GridDesc_M>
                   : kernel_normalization;
+                                        GridDesc_M_K,
+                                        GridDesc_M>;
     }
 }
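A selection sketch: the device op's isSweeponce_ flag asks whether a single K tile covers the whole reduction, so the kernel can keep x resident in registers and avoid a second global-memory pass; the selector above then picks the SweepOnce instantiation of either the Welford or the naive-variance gridwise kernel. The same predicate as a standalone function (name is mine; the comparison matches the isSweeponce_ assignment earlier in this patch):

    bool is_sweep_once(ck::index_t k_raw,
                       ck::index_t k_thread_cluster_size,
                       ck::index_t k_thread_slice_size)
    {
        return k_raw <= k_thread_cluster_size * k_thread_slice_size;
    }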
diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
index 136ac94e7f..9e380f9638 100644
--- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
+++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
@@ -17,11 +17,13 @@ template
+          index_t YDstVectorSize,
+          index_t SaveMeanInvStdDstVectorSize>
 struct GridwiseNormalizationSplitK2nd
 {
     static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd
                       (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");
 
+    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+                  "configuration, please check!");
+
     static_assert(XSrcVectorSize == YDstVectorSize);
     static_assert(XSrcVectorSize == GammaSrcVectorSize);
     static_assert(XSrcVectorSize == BetaSrcVectorSize);
@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd
     static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
         make_tuple(Number{}, Number{}));
 
+    using ThreadBufferLengths_M = Sequence;
+    static constexpr auto thread_buffer_desc_m =
+        make_naive_tensor_descriptor_packed(make_tuple(Number{}));
+
     using ThreadBufferLengths_M_1 = Sequence;
     static constexpr auto thread_buffer_desc_m_1 =
         make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1));
@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd
                                const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
                                const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
                                const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
+                               const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m,
+                               const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m,
                                index_t num_k_mean_var_count_iteration,
                                index_t num_k_block_tile_iteration,
                                index_t k_grid_size,
@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd
                                const GammaDataType* const __restrict__ p_gamma_global,
                                const BetaDataType* const __restrict__ p_beta_global,
                                YDataType* const __restrict__ p_y_global,
+                               SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+                               SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                                const YElementwiseOperation y_elementwise_op)
     {
         // Thread/Block id
@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd
         auto y_global_val_buf = make_dynamic_buffer(
             p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
 
+        auto save_mean_global_val_buf = make_dynamic_buffer(
+            p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
+
+        auto save_inv_std_global_val_buf = make_dynamic_buffer(
+            p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
+
         // VGPR
         StaticBuffer
             in_mean_thread_buf;
@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd
             var_thread_buf;
         StaticBuffer
             welford_count_thread_buf;
+        auto& inv_std_thread_buf = var_thread_buf;
 
         auto x_thread_buf = generate_tuple(
             [&](auto) {
@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd
                                    thread_k_cluster_id * YDstVectorSize),
             y_elementwise_op);
 
+        auto threadwise_mean_store =
+            ThreadwiseTensorSliceTransfer_v1r3,            // DimAccessOrder
+                                               0,          // SrcVectorDim
+                                               SaveMeanInvStdDstVectorSize, // ScalarPerVector
+                                               InMemoryDataOperationEnum::Set,
+                                               1,
+                                               true>(
+                save_mean_grid_desc_m,
+                make_multi_index(block_m_cluster_id * M_BlockTileSize +
+                                 thread_m_cluster_id * MThreadSliceSize),
+                PassThroughOp{});
+
+        auto threadwise_inv_std_store =
+            ThreadwiseTensorSliceTransfer_v1r3,            // DimAccessOrder
+                                               0,          // SrcVectorDim
+                                               SaveMeanInvStdDstVectorSize, // ScalarPerVector
+                                               InMemoryDataOperationEnum::Set,
+                                               1,
+                                               true>(
+                save_inv_std_grid_desc_m,
+                make_multi_index(block_m_cluster_id * M_BlockTileSize +
+                                 thread_m_cluster_id * MThreadSliceSize),
+                PassThroughOp{});
+
variance constexpr auto mean_var_count_thread_copy_step_I0_k = make_multi_index(I0, KThreadClusterSize); @@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd BlockwiseWelford::Run( mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I)); + + inv_std_thread_buf(I) = + type_convert(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon); }); - // step2: normalization + // step2: save mean and inverse std for backward (optional) + if(block_k_cluster_id == 0 && thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // step3: normalization constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); for(index_t k = 0; k < num_k_block_tile_iteration; ++k) @@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma y_thread_buf(iK0)(Number{}) = diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp index ff9712276c..15b412fba4 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp @@ -16,9 +16,11 @@ template struct GridwiseNormalizationWelfordVariance_mk_to_mk { @@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize); @@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = @@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ThreadClusterLengths_M_K, ThreadClusterArrangeOrder>; + using PassThroughOp = tensor_operation::element_wise::PassThrough; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -114,17 +127,18 @@ struct 
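The "merge mean and variance" step combines the partial statistics produced by each K-split block; in scalar form this is the standard Welford combine that BlockwiseWelford applies across the thread cluster. A self-contained sketch (textbook algorithm, not CK code):

#include <cstdio>

// Merging two Welford partials (mean, m2 = sum of squared deviations,
// count) -- the combine that step 1 above performs across K-split blocks.
struct Partial { double mean; double m2; int count; };

Partial merge(Partial a, Partial b)
{
    const int    n     = a.count + b.count;
    const double delta = b.mean - a.mean;
    return {a.mean + delta * b.count / n,
            a.m2 + b.m2 + delta * delta * a.count * b.count / n,
            n};
}

int main()
{
    // Partials of {1,2} and {3,4,5}; the merge must reproduce the
    // statistics of {1,2,3,4,5}: mean 3, population variance 2.
    const Partial a{1.5, 0.5, 2}, b{4.0, 2.0, 3};
    const Partial m = merge(a, b);
    std::printf("mean=%.2f var=%.2f\n", m.mean, m.m2 / m.count);
}

After the merge, 1 / sqrt(var + epsilon) is computed once per row and reused for both the saved output and the normalization itself, which is why the per-iteration divisor disappears in the hunks below.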
GridwiseNormalizationWelfordVariance_mk_to_mk const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, + const GridDesc_M& save_mean_grid_desc_m, + const GridDesc_M& save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) { - auto y_global_val_buf = make_dynamic_buffer( - p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); - auto x_thread_buf = generate_tuple( [&](auto) { return StaticBuffer var_thread_buf; + auto& inv_std_thread_buf = var_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); @@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk thread_k_cluster_id * YDstVectorSize), y_elementwise_op); + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, SweepOnce ? 
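The store origin handed to ThreadwiseTensorSliceTransfer_v1r3 above is plain tile arithmetic: a workgroup covers M_BlockTileSize rows of statistics and each thread a contiguous MThreadSliceSize sub-slice of them. A sketch of that index computation with assumed sizes:

#include <cstdio>

// Sketch of the destination offset used when constructing the mean /
// inv-std stores above: the block id selects the M tile, the thread's M
// cluster id the slice inside it. Sizes are illustrative assumptions.
int main()
{
    const int M_BlockTileSize  = 32; // statistic rows handled per workgroup
    const int MThreadSliceSize = 4;  // statistic rows handled per thread

    for(int block_id = 0; block_id < 2; ++block_id)
        for(int thread_m = 0; thread_m < 3; ++thread_m)
        {
            const int origin = block_id * M_BlockTileSize + thread_m * MThreadSliceSize;
            std::printf("block %d thread_m %d writes rows [%d, %d)\n",
                        block_id, thread_m, origin, origin + MThreadSliceSize);
        }
}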
0 : -K_BlockTileSize); @@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const auto beta_global_val_buf = make_dynamic_buffer( p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + auto threadwise_welford = ThreadwiseWelford(); threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); @@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk int count = threadwise_welford.cur_count_; BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + inv_std_thread_buf(I) = type_convert(1.0f) / + ck::math::sqrt(var_thread_buf(I) + epsilon); }); + // save mean and inverse std for backward (optional) + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // normalization static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma & beta y_thread_buf(iK0)(Number{}) = @@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk int count = threadwise_welford.cur_count_; BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon); }); + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; @@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma y_thread_buf(iK0)(Number{}) = diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp 
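Both the ThreadwiseWelford accumulation above and the CPU reference in the next file compute the same running statistics. A scalar sketch of that update, including the inverse standard deviation the patch now saves:

#include <cmath>
#include <cstdio>

// Running Welford update over one reduction row, as each thread performs
// over its K elements (and as the reference groupnorm below does on the
// host). Data and epsilon are illustrative.
int main()
{
    const double x[] = {2, 4, 4, 4, 5, 5, 7, 9};
    double mean = 0.0, m2 = 0.0;
    int count = 0;
    for(double v : x)
    {
        ++count;
        const double delta = v - mean;
        mean += delta / count;     // running mean
        m2 += delta * (v - mean);  // running sum of squared deviations
    }
    const double var     = m2 / count; // population variance, as in the kernel
    const double inv_std = 1.0 / std::sqrt(var + 1e-5);
    std::printf("mean=%.3f var=%.3f inv_std=%.3f\n", mean, var, inv_std); // 5, 4, ~0.5
}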
index 6a48528c54..62920fc03e 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp @@ -20,8 +20,9 @@ template + typename SaveMeanInvStdDataType, + typename ComputeDataType, + typename YElementwiseOperation> struct ReferenceGroupnorm : public device::BaseOperator { // x = [N, H, W, G, C] @@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator const Tensor& gamma, const Tensor& beta, Tensor& y, - AccElementwiseOperation acc_elementwise_op, + Tensor& save_mean, + Tensor& save_inv_std, + YElementwiseOperation y_elementwise_op, const std::vector lengths, - AccDataType epsilon) + ComputeDataType epsilon) : x_(x), gamma_(gamma), beta_(beta), y_(y), - acc_elementwise_op_(acc_elementwise_op), + save_mean_(save_mean), + save_inv_std_(save_inv_std), + y_elementwise_op_(y_elementwise_op), lengths_(lengths), epsilon_(epsilon) { @@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator const Tensor gamma_; const Tensor beta_; Tensor& y_; - AccElementwiseOperation acc_elementwise_op_; + Tensor& save_mean_; + Tensor& save_inv_std_; + YElementwiseOperation y_elementwise_op_; std::vector lengths_; - AccDataType epsilon_; + ComputeDataType epsilon_; }; // Invoker @@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator int G = arg.lengths_[3]; int C = arg.lengths_[4]; - Tensor mean({N, G}); - Tensor var({N, G}); + Tensor mean({N, G}); + Tensor var({N, G}); // Compute mean & var in [H, W, C] by Welford Algorithm // TODO - parallel for each HWC @@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator { for(int g = 0; g < G; ++g) { - AccDataType mean_val = type_convert(0.0f); - AccDataType var_val = type_convert(0.0f); - int32_t curr_count = 0; + ComputeDataType mean_val = type_convert(0.0f); + ComputeDataType var_val = type_convert(0.0f); + int32_t curr_count = 0; for(int h = 0; h < H; ++h) { @@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator for(int c = 0; c < C; ++c) { curr_count++; - AccDataType x = type_convert(arg.x_(n, h, w, g, c)); - AccDataType delta = x - mean_val; + ComputeDataType x = + type_convert(arg.x_(n, h, w, g, c)); + ComputeDataType delta = x - mean_val; mean_val += delta / curr_count; - AccDataType delta2 = x - mean_val; + ComputeDataType delta2 = x - mean_val; var_val += delta * delta2; } } @@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator mean(n, g) = mean_val; var(n, g) = var_val / curr_count; + + arg.save_mean_(n, g) = ck::type_convert(mean(n, g)); + + ComputeDataType divisor = + static_cast(1) / ck::math::sqrt(var(n, g) + arg.epsilon_); + arg.save_inv_std_(n, g) = ck::type_convert(divisor); } } @@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator { for(int c = 0; c < C; ++c) { - AccDataType x = type_convert(arg.x_(n, h, w, g, c)); - AccDataType gamma = type_convert(arg.gamma_(g, c)); - AccDataType beta = type_convert(arg.beta_(g, c)); - AccDataType mean_val = type_convert(mean(n, g)); - AccDataType var_val = type_convert(var(n, g)); - AccDataType y = gamma * (x - mean_val) / - ck::math::sqrt(arg.epsilon_ + var_val) + - beta; - arg.acc_elementwise_op_(y, y); + ComputeDataType x = + type_convert(arg.x_(n, h, w, g, c)); + ComputeDataType gamma = + type_convert(arg.gamma_(g, c)); + ComputeDataType beta = + type_convert(arg.beta_(g, c)); + ComputeDataType mean_val = + type_convert(mean(n, g)); + ComputeDataType var_val 
= type_convert(var(n, g)); + ComputeDataType y = gamma * (x - mean_val) / + ck::math::sqrt(arg.epsilon_ + var_val) + + beta; + arg.y_elementwise_op_(y, y); arg.y_(n, h, w, g, c) = type_convert(y); } } @@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator const Tensor& gamma, const Tensor& beta, Tensor& y, - AccElementwiseOperation acc_elementwise_op, + Tensor& save_mean, + Tensor& save_inv_std, + YElementwiseOperation y_elementwise_op, const std::vector lengths, - AccDataType epsilon) + ComputeDataType epsilon) { - return Argument{x, gamma, beta, y, acc_elementwise_op, lengths, epsilon}; + return Argument{ + x, gamma, beta, y, save_mean, save_inv_std, y_elementwise_op, lengths, epsilon}; } static auto MakeInvoker() { return Invoker{}; } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp index 9994a2f9f7..444ae970c1 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp @@ -20,8 +20,9 @@ template struct ReferenceLayernorm : public device::BaseOperator @@ -36,15 +37,19 @@ struct ReferenceLayernorm : public device::BaseOperator const Tensor& gamma_n, const Tensor& beta_n, Tensor& y_m_n, - AccElementwiseOperation acc_elementwise_op, + Tensor& save_mean_m, + Tensor& save_inv_std_m, + YElementwiseOperation y_elementwise_op, const std::vector lengths, const std::vector reduceDims, - AccDataType epsilon) + ComputeDataType epsilon) : x_m_n_(x_m_n), gamma_n_(gamma_n), beta_n_(beta_n), y_m_n_(y_m_n), - acc_elementwise_op_(acc_elementwise_op), + save_mean_m_(save_mean_m), + save_inv_std_m_(save_inv_std_m), + y_elementwise_op_(y_elementwise_op), lengths_(lengths), reduceDims_(reduceDims), epsilon_(epsilon) @@ -55,10 +60,12 @@ struct ReferenceLayernorm : public device::BaseOperator const Tensor gamma_n_; const Tensor beta_n_; Tensor& y_m_n_; - AccElementwiseOperation acc_elementwise_op_; + Tensor& save_mean_m_; + Tensor& save_inv_std_m_; + YElementwiseOperation y_elementwise_op_; std::vector lengths_; std::vector reduceDims_; - AccDataType epsilon_; + ComputeDataType epsilon_; }; // Invoker @@ -69,8 +76,8 @@ struct ReferenceLayernorm : public device::BaseOperator int M = arg.lengths_[0]; int N = arg.lengths_[1]; - Tensor mean({M}); - Tensor var({M}); + Tensor mean({M}); + Tensor var({M}); for(int m = 0; m < M; ++m) { @@ -79,7 +86,7 @@ struct ReferenceLayernorm : public device::BaseOperator for(int n = 0; n < N; ++n) { - auto x_val = ck::type_convert(arg.x_m_n_(m, n)); + auto x_val = ck::type_convert(arg.x_m_n_(m, n)); mean(m) += x_val; var(m) += x_val * x_val; } @@ -90,17 +97,21 @@ struct ReferenceLayernorm : public device::BaseOperator for(int m = 0; m < M; ++m) { - AccDataType divisor = - static_cast(1) / ck::math::sqrt(var(m) + arg.epsilon_); + ComputeDataType divisor = + static_cast(1) / ck::math::sqrt(var(m) + arg.epsilon_); for(int n = 0; n < N; ++n) { - auto x_val = ck::type_convert(arg.x_m_n_(m, n)); - auto y_val = (x_val - mean(m)) * divisor; - y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n); - arg.acc_elementwise_op_(y_val, y_val); + auto x_val = ck::type_convert(arg.x_m_n_(m, n)); + auto gamma_val = ck::type_convert(arg.gamma_n_(n)); + auto beta_val = ck::type_convert(arg.beta_n_(n)); + auto y_val = (x_val - mean(m)) * divisor; + y_val = (y_val * gamma_val) + beta_val; + arg.y_elementwise_op_(y_val, y_val); 
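Note that the layernorm reference uses the naive moment method rather than Welford: it accumulates E[x] and E[x^2] over the row and takes var = E[x^2] - mean^2. A runnable sketch of one row, including the statistics that get saved:

#include <cmath>
#include <cstdio>

// One row of the layernorm reference above: naive two-moment variance,
// then y = (x - mean) * inv_std * gamma + beta. Values are illustrative.
int main()
{
    const double x[]     = {1, 2, 3, 4};
    const double gamma[] = {1, 1, 1, 1};
    const double beta[]  = {0, 0, 0, 0};
    const int    N       = 4;
    const double eps     = 1e-4;

    double mean = 0, var = 0;
    for(int n = 0; n < N; ++n)
    {
        mean += x[n];
        var += x[n] * x[n];
    }
    mean /= N;
    var = var / N - mean * mean; // E[x^2] - E[x]^2

    const double inv_std = 1.0 / std::sqrt(var + eps); // this is what gets saved
    for(int n = 0; n < N; ++n)
        std::printf("y[%d]=% .4f\n", n, (x[n] - mean) * inv_std * gamma[n] + beta[n]);
    std::printf("save_mean=%.4f save_inv_std=%.4f\n", mean, inv_std);
}

The moment method is cheaper per element but less numerically robust than Welford for long rows or large means, which is presumably why the GPU kernels and the groupnorm reference stick with Welford.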
arg.y_m_n_(m, n) = ck::type_convert(y_val); } + arg.save_mean_m_(m) = ck::type_convert(mean(m)); + arg.save_inv_std_m_(m) = ck::type_convert(divisor); } return 0; @@ -140,13 +151,23 @@ struct ReferenceLayernorm : public device::BaseOperator const Tensor& gamma_n, const Tensor& beta_n, Tensor& y_m_n, - AccElementwiseOperation acc_elementwise_op, + Tensor& save_mean_m, + Tensor& save_inv_std_m, + YElementwiseOperation y_elementwise_op, const std::vector lengths, const std::vector reduceDims, - AccDataType epsilon) + ComputeDataType epsilon) { - return Argument{ - x_m_n, gamma_n, beta_n, y_m_n, acc_elementwise_op, lengths, reduceDims, epsilon}; + return Argument{x_m_n, + gamma_n, + beta_n, + y_m_n, + save_mean_m, + save_inv_std_m, + y_elementwise_op, + lengths, + reduceDims, + epsilon}; } static auto MakeInvoker() { return Invoker{}; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp index 8e90a7ea98..229de41b5e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp @@ -19,13 +19,13 @@ namespace instance { #ifdef CK_ENABLE_FP16 // FP16 void add_device_normalization_rank_2_1_f16_instances( - std::vector>>&); + std::vector>>&); void add_device_normalization_rank_4_3_f16_instances( - std::vector>>&); + std::vector>>&); void add_device_normalization_rank_5_3_f16_instances( - std::vector>>&); + std::vector>>&); #endif #ifdef CK_ENABLE_FP32 // FP32 @@ -42,14 +42,15 @@ template struct DeviceOperationInstanceFactory> @@ -57,8 +58,8 @@ struct DeviceOperationInstanceFactory; @@ -68,7 +69,8 @@ struct DeviceOperationInstanceFactory> op_ptrs; #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { if constexpr(Rank == 2 && NumReduceDim == 1) { @@ -86,7 +88,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { if constexpr(Rank == 2 && NumReduceDim == 1) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp b/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp index 2391775299..ae9c09cacb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp @@ -19,7 +19,7 @@ namespace instance { // FP16 void add_device_normalization_rank_5_3_swish_f16_instances( - std::vector>>&); + std::vector>>&); // FP32 void add_device_normalization_rank_5_3_swish_f32_instances( @@ -27,20 +27,21 @@ void add_device_normalization_rank_5_3_swish_f32_instances( // [x, gamma, beta, y] = [f16, f32, f32, f16] void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances( - std::vector>>&); + std::vector>>&); template struct DeviceOperationInstanceFactory< ck::tensor_operation::device::DeviceNormalization> @@ -48,8 +49,8 @@ struct DeviceOperationInstanceFactory< using DeviceOp = DeviceNormalization; @@ -59,7 +60,8 @@ struct DeviceOperationInstanceFactory< std::vector> op_ptrs; if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { if constexpr(Rank == 5 && NumReduceDim == 3) { @@ -67,7 +69,8 @@ struct DeviceOperationInstanceFactory< } } else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v 
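The factory keys instance registration on compile-time type equality, and the match now also includes SaveMeanInvStdDataType. A minimal sketch of that dispatch pattern, with placeholder types and a string payload standing in for real instance pointers:

#include <string>
#include <type_traits>
#include <vector>

// Sketch of the DeviceOperationInstanceFactory pattern: each supported
// type combination gets its own compile-time branch that appends matching
// instances; unmatched combinations yield an empty list. F16 is a
// stand-in alias for ck::half_t, illustration only.
using F16 = short;
using F32 = float;

template <typename XDataType, typename YDataType, typename SaveMeanInvStdDataType>
std::vector<std::string> get_instances()
{
    std::vector<std::string> op_ptrs;
    if constexpr(std::is_same_v<XDataType, F16> && std::is_same_v<YDataType, F16> &&
                 std::is_same_v<SaveMeanInvStdDataType, F32>)
    {
        op_ptrs.push_back("rank_2_1_f16_instances"); // would call add_device_..._instances
    }
    else if constexpr(std::is_same_v<XDataType, F32> && std::is_same_v<YDataType, F32> &&
                      std::is_same_v<SaveMeanInvStdDataType, F32>)
    {
        op_ptrs.push_back("rank_2_1_f32_instances");
    }
    return op_ptrs;
}

int main() { return get_instances<F16, F16, F32>().size() == 1 ? 0 : 1; }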
&& is_same_v && + is_same_v) { if constexpr(Rank == 5 && NumReduceDim == 3) { @@ -75,7 +78,8 @@ struct DeviceOperationInstanceFactory< } } else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { if constexpr(Rank == 5 && NumReduceDim == 3) { diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp index 762da1c6ae..439e724199 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Pass = ck::tensor_operation::element_wise::PassThrough; void add_device_normalization_rank_5_3_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp index aa662b7dfe..5f42d073ff 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Swish = ck::tensor_operation::element_wise::Swish; void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp index bc5cd801ae..63aea024da 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Swish = ck::tensor_operation::element_wise::Swish; void add_device_normalization_rank_5_3_swish_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp index 0d235f1fa7..e15ff4b6d0 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Pass = ck::tensor_operation::element_wise::PassThrough; void add_device_normalization_rank_2_1_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp index 6bc3950062..4152c6ebbf 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using 
Pass = ck::tensor_operation::element_wise::PassThrough; void add_device_normalization_rank_4_3_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git a/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp index 7aa3da8eed..488f34b4b3 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp +++ b/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp @@ -22,25 +22,25 @@ template using device_normalization_f16_instances = // clang-format off std::tuple < - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl // clang-format on >; @@ -48,150 +48,150 @@ template using device_normalization_splitk_f16_instances = // clang-format off std::tuple < - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - 
DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl // clang-format on >; template using device_normalization_f16_generic_instance = std::tuple< // clang-format off - DeviceNormalizationImpl + DeviceNormalizationImpl // clang-format on >; template using device_normalization_f32_instances = std::tuple< // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl // clang-format on >; template using device_normalization_splitk_f32_instances = std::tuple< // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, 
MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl // clang-format on >; template using device_normalization_f32_generic_instance = std::tuple< // clang-format off - DeviceNormalizationImpl + DeviceNormalizationImpl // clang-format on >; template using device_normalization_f16_f32_f32_f16_instances = std::tuple< // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size 
+ DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl // clang-format on >; template using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple< // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, // irregular size - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl, - DeviceNormalizationSplitKImpl + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl // clang-format on >; template using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple< // clang-format off - DeviceNormalizationImpl + DeviceNormalizationImpl // clang-format on >; diff --git a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp index 1fd9c81109..ae42919db6 100644 --- a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp @@ -80,6 +80,8 @@ bool profile_elementwise_layernorm_impl(int do_verification, Tensor beta(gammaBetaLength); Tensor y(length); Tensor host_y(length); + Tensor host_save_mean({M}); + Tensor host_save_inv_std({M}); switch(init_method) { @@ -152,14 +154,23 @@ bool profile_elementwise_layernorm_impl(int do_verification, BetaDataType, 
YDataType, AccDataType, + AccDataType, PassThrough, Rank, NumReduceDim>; ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4); - auto ref_invoker = ref.MakeInvoker(); + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + PassThrough{}, + {M, N}, + {1}, + 1e-4); + auto ref_invoker = ref.MakeInvoker(); ref_invoker.Run(ref_argument); } diff --git a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp index 4c3d0a0450..46591a3525 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp @@ -66,12 +66,15 @@ void host_gemm_layernorm(Tensor& h_m_n, BetaDataType, HDataType, AccDataType, + AccDataType, HElementOp, 2, 1>; Tensor e_m_n(HostTensorDescriptor{M, N}); Tensor c_m_n(HostTensorDescriptor{M, N}); + Tensor save_mean({M}); + Tensor save_inv_std({M}); auto ref_gemm = ReferenceGemm{}; auto ref_gemm_invoker = ref_gemm.MakeInvoker(); @@ -97,7 +100,7 @@ void host_gemm_layernorm(Tensor& h_m_n, auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); auto ref_layernorm_argument = ref_layernorm.MakeArgument( - e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon); + e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon); ref_layernorm_invoker.Run(ref_layernorm_argument); } diff --git a/profiler/include/profiler/profile_groupnorm_impl.hpp b/profiler/include/profiler/profile_groupnorm_impl.hpp index f88ba8453c..4715853d2a 100644 --- a/profiler/include/profiler/profile_groupnorm_impl.hpp +++ b/profiler/include/profiler/profile_groupnorm_impl.hpp @@ -21,8 +21,10 @@ namespace profiler { template + typename ComputeDataType, + typename YDataType, + typename SaveMeanInvStdDataType, + bool SaveMeanInvStd> bool profile_groupnorm_impl(int do_verification, int init_method, bool do_log, @@ -34,6 +36,7 @@ bool profile_groupnorm_impl(int do_verification, if(length.size() != 5) return false; + index_t N = length[0]; index_t G = length[3]; index_t C = length[4]; @@ -45,7 +48,14 @@ bool profile_groupnorm_impl(int do_verification, Tensor gamma(gammaBetaLength); Tensor beta(gammaBetaLength); Tensor y(length); + Tensor save_mean({N, G}); + Tensor save_inv_std({N, G}); + Tensor host_y(length); + Tensor host_save_mean({N, G}); + Tensor host_save_inv_std({N, G}); + + std::vector strideSaveMeanInvStd = {1}; switch(init_method) { @@ -69,6 +79,9 @@ bool profile_groupnorm_impl(int do_verification, DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize()); + DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * + save_inv_std.mDesc.GetElementSpaceSize()); x_dev.ToDevice(x.mData.data()); gamma_dev.ToDevice(gamma.mData.data()); @@ -78,8 +91,8 @@ bool profile_groupnorm_impl(int do_verification, using DeviceOp = ck::tensor_operation::device::DeviceNormalization; @@ -97,38 +110,70 @@ bool profile_groupnorm_impl(int do_verification, if(do_verification) { - using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm; + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnorm; 
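Throughout the new interface the statistics outputs are optional: passing nullptr for p_save_mean / p_save_inv_std disables those stores, which is exactly what the profiler's SaveMeanInvStd branches below exercise. A sketch of the convention (a hypothetical simplification of the kernel-side guards):

#include <cstdio>

// Null destination pointers simply disable the corresponding store, so
// one kernel serves both the inference path and the training path that
// needs the statistics for backward.
void write_stats(float* p_save_mean, float* p_save_inv_std, float mean, float inv_std)
{
    if(p_save_mean != nullptr)
        *p_save_mean = mean;
    if(p_save_inv_std != nullptr)
        *p_save_inv_std = inv_std;
}

int main()
{
    float m = 0.f, s = 0.f;
    write_stats(&m, &s, 3.f, 0.5f);           // saving enabled
    write_stats(nullptr, nullptr, 3.f, 0.5f); // saving disabled: no-op
    std::printf("mean=%.1f inv_std=%.1f\n", m, s);
}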
ReferenceInstance ref; - auto ref_argument = ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, length, 1e-6); - auto ref_invoker = ref.MakeInvoker(); + auto ref_argument = ref.MakeArgument( + x, gamma, beta, host_y, host_save_mean, host_save_inv_std, PassThrough{}, length, 1e-6); + auto ref_invoker = ref.MakeInvoker(); ref_invoker.Run(ref_argument); } int num_kernel = 0; + auto f_get_argument = [&](auto& inst_ptr) { + if constexpr(SaveMeanInvStd) + return inst_ptr->MakeArgumentPointer( + length, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + gammaBetaStride, + gammaBetaStride, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + std::vector{save_inv_std.mDesc.GetStrides().begin(), + save_inv_std.mDesc.GetStrides().end()}, + reduce_dim, + 1e-6, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), + save_mean_dev.GetDeviceBuffer(), + save_inv_std_dev.GetDeviceBuffer(), + PassThrough{}); + else + return inst_ptr->MakeArgumentPointer( + length, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + gammaBetaStride, + gammaBetaStride, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + std::vector{save_inv_std.mDesc.GetStrides().begin(), + save_inv_std.mDesc.GetStrides().end()}, + reduce_dim, + 1e-6, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), + nullptr, + nullptr, + PassThrough{}); + }; + for(auto& inst_ptr : instance_ptrs) { - auto argument_ptr = inst_ptr->MakeArgumentPointer( - length, - std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, - gammaBetaStride, - gammaBetaStride, - std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, - reduce_dim, - 1e-6, - x_dev.GetDeviceBuffer(), - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - y_dev.GetDeviceBuffer(), - nullptr, - nullptr, - PassThrough{}); + auto argument_ptr = f_get_argument(inst_ptr); if(inst_ptr->IsSupportedArgument(argument_ptr.get())) { @@ -152,6 +197,10 @@ bool profile_groupnorm_impl(int do_verification, beta.mDesc.GetElementSize() * sizeof(BetaDataType) + y.mDesc.GetElementSize() * sizeof(YDataType); + if constexpr(SaveMeanInvStd) + num_bytes += save_mean.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType) + + save_inv_std.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType); + float gb_per_sec = num_bytes / 1.E6 / avg_time; if(time_kernel) @@ -168,9 +217,22 @@ bool profile_groupnorm_impl(int do_verification, if(do_verification) { y_dev.FromDevice(y.mData.data()); - bool pass = ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); + if constexpr(SaveMeanInvStd) + { + save_mean_dev.FromDevice(save_mean.mData.data()); + pass &= ck::utils::check_err( + save_mean.mData, host_save_mean.mData, "Error: Incorrect results", 1e-3, 1e-3); + + save_inv_std_dev.FromDevice(save_inv_std.mData.data()); + pass &= ck::utils::check_err(save_inv_std.mData, + host_save_inv_std.mData, + "Error: Incorrect results", + 1e-3, + 1e-3); + } + if(do_log) { LogRangeAsType(std::cout << "x : ", x.mData, ",") << std::endl; diff --git a/profiler/include/profiler/profile_layernorm_impl.hpp b/profiler/include/profiler/profile_layernorm_impl.hpp index f969646c2f..7c214af019 
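With saving enabled, the bandwidth figure counts the two extra statistic writes on top of the x/gamma/beta reads and the y write. A worked example of that accounting with assumed layernorm sizes:

#include <cstdio>

// Worked example of the profiler's effective-bandwidth accounting: read
// x/gamma/beta, write y, plus the two saved statistics vectors when
// enabled. gb_per_sec = bytes / 1e6 / time_ms, as in the code above.
int main()
{
    const long long M = 4096, N = 4096;
    const int sizeof_x = 2, sizeof_y = 2;        // fp16 input / output
    const int sizeof_gamma = 2, sizeof_beta = 2; // fp16 affine parameters
    const int sizeof_stat = 4;                   // fp32 mean / inv-std

    long long num_bytes = M * N * sizeof_x + N * sizeof_gamma + N * sizeof_beta +
                          M * N * sizeof_y;
    num_bytes += M * sizeof_stat * 2; // save_mean + save_inv_std

    const double avg_time_ms = 0.5; // hypothetical measured kernel time
    std::printf("%.2f GB/s\n", num_bytes / 1.e6 / avg_time_ms);
}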
100644 --- a/profiler/include/profiler/profile_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_layernorm_impl.hpp @@ -21,6 +21,8 @@ template bool profile_layernorm_impl(int do_verification, int init_method, @@ -43,13 +45,19 @@ bool profile_layernorm_impl(int do_verification, Tensor gamma(reduce_length); Tensor beta(reduce_length); Tensor y(length); + Tensor save_mean({length[0]}); + Tensor save_inv_std({length[0]}); Tensor host_y(length); + Tensor host_save_mean({length[0]}); + Tensor host_save_inv_std({length[0]}); std::vector strideXY = std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; std::vector strideGammaBeta = strideXY; strideGammaBeta[0] = 0; + std::vector strideSaveMeanInvStd = {1}; + switch(init_method) { case 0: @@ -75,6 +83,9 @@ bool profile_layernorm_impl(int do_verification, DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize()); + DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * + save_inv_std.mDesc.GetElementSpaceSize()); x_dev.ToDevice(x.mData.data()); gamma_dev.ToDevice(gamma.mData.data()); @@ -86,8 +97,8 @@ bool profile_layernorm_impl(int do_verification, using DeviceOp = ck::tensor_operation::device::DeviceNormalization; @@ -105,40 +116,74 @@ bool profile_layernorm_impl(int do_verification, if(do_verification) { - using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm; + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernorm; ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, length, reduce_dim, 1e-4); - auto ref_invoker = ref.MakeInvoker(); + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + PassThrough{}, + length, + reduce_dim, + 1e-4); + auto ref_invoker = ref.MakeInvoker(); ref_invoker.Run(ref_argument); } int num_kernel = 0; + auto f_get_argument = [&](auto& inst_ptr) { + if constexpr(SaveMeanInvStd) + return inst_ptr->MakeArgumentPointer(length, + strideXY, + strideGammaBeta, + strideGammaBeta, + strideXY, + strideSaveMeanInvStd, + strideSaveMeanInvStd, + reduce_dim, + 1e-4, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), + save_mean_dev.GetDeviceBuffer(), + save_inv_std_dev.GetDeviceBuffer(), + PassThrough{}); + else + return inst_ptr->MakeArgumentPointer(length, + strideXY, + strideGammaBeta, + strideGammaBeta, + strideXY, + strideSaveMeanInvStd, + strideSaveMeanInvStd, + reduce_dim, + 1e-4, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), + nullptr, + nullptr, + PassThrough{}); + }; + for(auto& inst_ptr : instance_ptrs) { - auto argument_ptr = inst_ptr->MakeArgumentPointer(length, - strideXY, - strideGammaBeta, - strideGammaBeta, - strideXY, - reduce_dim, - 1e-4, - x_dev.GetDeviceBuffer(), - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - y_dev.GetDeviceBuffer(), - nullptr, - nullptr, - PassThrough{}); + auto argument_ptr = f_get_argument(inst_ptr); if(inst_ptr->IsSupportedArgument(argument_ptr.get())) { @@ -168,6 +213,10 @@ bool profile_layernorm_impl(int do_verification, beta.mDesc.GetElementSize() * sizeof(BetaDataType) + y.mDesc.GetElementSize() * 
sizeof(YDataType); + if constexpr(SaveMeanInvStd) + num_bytes += save_mean.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType) + + save_inv_std.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType); + float gb_per_sec = num_bytes / 1.E6 / avg_time; if(time_kernel) @@ -184,10 +233,23 @@ bool profile_layernorm_impl(int do_verification, if(do_verification) { y_dev.FromDevice(y.mData.data()); - bool pass = ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); + if constexpr(SaveMeanInvStd) + { + save_mean_dev.FromDevice(save_mean.mData.data()); + pass &= ck::utils::check_err( + save_mean.mData, host_save_mean.mData, "Error: Incorrect results", 1e-3, 1e-3); + + save_inv_std_dev.FromDevice(save_inv_std.mData.data()); + pass &= ck::utils::check_err(save_inv_std.mData, + host_save_inv_std.mData, + "Error: Incorrect results", + 1e-3, + 1e-3); + } + if(do_log) { LogRangeAsType(std::cout << "x : ", x.mData, ",") << std::endl; diff --git a/profiler/src/profile_groupnorm.cpp b/profiler/src/profile_groupnorm.cpp index d55784ff0a..079f6f0db7 100644 --- a/profiler/src/profile_groupnorm.cpp +++ b/profiler/src/profile_groupnorm.cpp @@ -93,12 +93,12 @@ int profile_groupnorm(int argc, char* argv[]) if(data_type == ck::DataTypeEnum::Float) { - ck::profiler::profile_groupnorm_impl( + ck::profiler::profile_groupnorm_impl( do_verification, init_method, do_log, time_kernel, length); } else if(data_type == ck::DataTypeEnum::Half) { - ck::profiler::profile_groupnorm_impl( + ck::profiler::profile_groupnorm_impl( do_verification, init_method, do_log, time_kernel, length); } else diff --git a/profiler/src/profile_layernorm.cpp b/profiler/src/profile_layernorm.cpp index 7bf210e678..fdeaa036b2 100644 --- a/profiler/src/profile_layernorm.cpp +++ b/profiler/src/profile_layernorm.cpp @@ -82,12 +82,12 @@ int profile_layernorm(int argc, char* argv[]) if(data_type == ck::DataTypeEnum::Half) { - ck::profiler::profile_layernorm_impl( + ck::profiler::profile_layernorm_impl( do_verification, init_method, do_log, time_kernel, length); } else if(data_type == ck::DataTypeEnum::Float) { - ck::profiler::profile_layernorm_impl( + ck::profiler::profile_layernorm_impl( do_verification, init_method, do_log, time_kernel, length); } else diff --git a/test/normalization/test_groupnorm_fp16.cpp b/test/normalization/test_groupnorm_fp16.cpp index 325ea75fe5..67387ad40b 100644 --- a/test/normalization/test_groupnorm_fp16.cpp +++ b/test/normalization/test_groupnorm_fp16.cpp @@ -12,11 +12,12 @@ template class TestGroupnorm : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using ComputeDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; + using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>; void Run() { @@ -37,7 +38,9 @@ class TestGroupnorm : public ::testing::Test GammaDataType, BetaDataType, ComputeDataType, - YDataType>(true, 2, false, false, length); + YDataType, + SaveMeanInvStdDataType, + true>(true, 2, false, false, length); EXPECT_TRUE(success); } } @@ -45,7 +48,7 @@ class TestGroupnorm : public ::testing::Test using 
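profile_groupnorm.cpp and profile_layernorm.cpp bridge the runtime data-type flag to the fully typed profiler template. A sketch of that enum-to-template dispatch; the stand-in types and the hard-coded SaveMeanInvStd value are illustrative only:

#include <cstdio>

// Runtime-to-compile-time dispatch: a runtime enum selects one fully
// typed instantiation. The last template parameter models the new
// SaveMeanInvStd switch the profilers gained above.
enum class DataTypeEnum { Half, Float };

template <typename XDataType, typename SaveMeanInvStdDataType, bool SaveMeanInvStd>
bool profile_impl()
{
    std::printf("sizeof(X)=%zu save=%d\n", sizeof(XDataType), int(SaveMeanInvStd));
    return true;
}

bool profile(DataTypeEnum dt)
{
    if(dt == DataTypeEnum::Half)
        return profile_impl<short, float, false>(); // short stands in for fp16
    if(dt == DataTypeEnum::Float)
        return profile_impl<float, float, false>();
    return false;
}

int main() { return profile(DataTypeEnum::Float) ? 0 : 1; }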
KernelTypes = ::testing::Types< // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> - std::tuple>; + std::tuple>; TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); TYPED_TEST(TestGroupnorm, Test_FP16) { this->Run(); } diff --git a/test/normalization/test_groupnorm_fp32.cpp b/test/normalization/test_groupnorm_fp32.cpp index ec88442fc0..136079f104 100644 --- a/test/normalization/test_groupnorm_fp32.cpp +++ b/test/normalization/test_groupnorm_fp32.cpp @@ -12,11 +12,12 @@ template class TestGroupnorm : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using ComputeDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; + using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>; void Run() { @@ -35,7 +36,9 @@ class TestGroupnorm : public ::testing::Test GammaDataType, BetaDataType, ComputeDataType, - YDataType>(true, 2, false, false, length); + YDataType, + SaveMeanInvStdDataType, + true>(true, 2, false, false, length); EXPECT_TRUE(success); } } @@ -43,7 +46,7 @@ class TestGroupnorm : public ::testing::Test using KernelTypes = ::testing::Types< // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> - std::tuple>; + std::tuple>; TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); TYPED_TEST(TestGroupnorm, Test_FP32) { this->Run(); } diff --git a/test/normalization/test_layernorm2d_fp16.cpp b/test/normalization/test_layernorm2d_fp16.cpp index 2222740fcc..54bab25257 100644 --- a/test/normalization/test_layernorm2d_fp16.cpp +++ b/test/normalization/test_layernorm2d_fp16.cpp @@ -12,11 +12,12 @@ template class TestLayernorm2d : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using ComputeDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; + using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>; void Run() { @@ -31,6 +32,8 @@ class TestLayernorm2d : public ::testing::Test BetaDataType, ComputeDataType, YDataType, + SaveMeanInvStdDataType, + true, 2>(true, 2, false, false, length); EXPECT_TRUE(success); } @@ -39,7 +42,7 @@ class TestLayernorm2d : public ::testing::Test using KernelTypes = ::testing::Types< // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> - std::tuple>; + std::tuple>; TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); TYPED_TEST(TestLayernorm2d, Test_FP16) { this->Run(); } diff --git a/test/normalization/test_layernorm2d_fp32.cpp b/test/normalization/test_layernorm2d_fp32.cpp index 30fbe06c60..ee9646a4d5 100644 --- a/test/normalization/test_layernorm2d_fp32.cpp +++ b/test/normalization/test_layernorm2d_fp32.cpp @@ -12,11 +12,12 @@ template class TestLayernorm2d : public ::testing::Test { protected: - using XDataType
= std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using ComputeDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; + using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>; void Run() { @@ -31,6 +32,8 @@ class TestLayernorm2d : public ::testing::Test BetaDataType, ComputeDataType, YDataType, + SaveMeanInvStdDataType, + true, 2>(true, 2, false, false, length); EXPECT_TRUE(success); } @@ -39,7 +42,7 @@ class TestLayernorm2d : public ::testing::Test using KernelTypes = ::testing::Types< // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> - std::tuple>; + std::tuple>; TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); TYPED_TEST(TestLayernorm2d, Test_FP32) { this->Run(); }
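The tests carry their whole type configuration as a std::tuple and unpack it positionally, which is why each tuple gains a sixth element for SaveMeanInvStdDataType. A minimal sketch of that pattern outside GoogleTest:

#include <tuple>
#include <type_traits>

// One tuple per test instantiation, unpacked positionally; the sixth
// element is the new SaveMeanInvStdDataType slot.
template <typename Tuple>
struct Config
{
    using XDataType              = std::tuple_element_t<0, Tuple>;
    using GammaDataType          = std::tuple_element_t<1, Tuple>;
    using BetaDataType           = std::tuple_element_t<2, Tuple>;
    using ComputeDataType        = std::tuple_element_t<3, Tuple>;
    using YDataType              = std::tuple_element_t<4, Tuple>;
    using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>;
};

using Fp32Case = Config<std::tuple<float, float, float, float, float, float>>;
static_assert(std::is_same_v<Fp32Case::SaveMeanInvStdDataType, float>);

int main() {}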