Layernorm and groupnorm support to save mean and inverse std in forward (#929)

* save mean and inverse std in normalization

* Save mean and inverse std in splitK

* Vector save mean and inv std

* Modify instance for save mean and std

* simplify the layernorm example

* Save mean and std in groupnorm example

* Save mean and inv std in ckProfiler and test

* Remove compute data type from base class

* Save mean and inv std in client example

* Add changelog

* clang format

* Fix compile error

* Refine naming

* Avoid error in bf16

* revert changelog
This commit is contained in:
rocking
2023-10-19 07:36:29 +08:00
committed by GitHub
parent 58338bb203
commit 3696fe1c76
38 changed files with 1393 additions and 544 deletions

View File

@@ -12,12 +12,14 @@
#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
using XDataType = ck::half_t;
using GammaDataType = float;
using BetaDataType = float;
using YDataType = ck::half_t;
using ComputeDataType = float;
using Swish = ck::tensor_operation::element_wise::Swish;
using XDataType = ck::half_t;
using GammaDataType = float;
using BetaDataType = float;
using YDataType = ck::half_t;
using SaveMeanInvStdDataType = float;
using Swish = ck::tensor_operation::element_wise::Swish;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
@@ -49,19 +51,24 @@ int main(int argc, char* argv[])
std::size_t xy_size = N * H * W * G * C;
std::size_t gamma_beta_size = G * C;
std::vector<ck::index_t> xy_strides = {H * W * G * C, W * G * C, G * C, C, 1};
std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1};
std::vector<ck::index_t> xy_strides = {H * W * G * C, W * G * C, G * C, C, 1};
std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1};
std::vector<ck::index_t> save_mean_inv_std_strides = {G, 1};
SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size);
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size);
SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
#endif
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
SaveMeanInvStdDataType,
Swish,
Rank,
NumReduceDim>;
@@ -75,19 +82,26 @@ int main(int argc, char* argv[])
const auto& generic_op_ptr = op_ptrs[0];
auto generic_argument_ptr =
generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
{1, 2, 4}, // reduceDims
generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
save_mean_inv_std_strides, // save_mean Strides
save_mean_inv_std_strides, // save_inv_std Strides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
Swish{});
if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
@@ -107,21 +121,29 @@ int main(int argc, char* argv[])
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
nullptr,
nullptr,
Swish{});
auto& op_ptr = op_ptrs[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
save_mean_inv_std_strides, // save_mean Strides
save_mean_inv_std_strides, // save_inv_std Strides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
Swish{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -139,6 +161,10 @@ int main(int argc, char* argv[])
sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size +
sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size;
#ifdef SAVE_MEAN_INV_STD
num_byte += sizeof(SaveMeanInvStdDataType) * N * G * 2;
#endif
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
@@ -169,20 +195,28 @@ int main(int argc, char* argv[])
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
nullptr,
nullptr,
Swish{});
auto argument_ptr =
op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
save_mean_inv_std_strides, // save_mean Strides
save_mean_inv_std_strides, // save_inv_std Strides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
Swish{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();