Batchnorm-forward and Batchnorm-infer Implemented using generic kernels (#320)

* Implement multiple-reduction in one kernel (kernels, device ops, examples)

* Add generic elementwise kernel and device interface

* Add generator for normal-distributed data initialization

* Add host refer implementation of batchnorm-forward and batchnorm-infer

* Add examples for implementing batchnorm-forward and batchnorm-infer using generic kernels

* Remove un-needed including in batchnorm example

* Renaming generic_elementwise to elementiwise in kernel and device classes/functions

* Change in gemm_layernorm examples to use DeviceElementwise instead of Device5AryElementwise

* Change in exampe 19_binary_elementwise to use DeviceElementwise instead of DeviceBinaryElementwise

* Change in device_cgemm_4gemm_xdl_cshuffle.hpp to use kernel_elementwise instead of kernel_binary_elementwise

* Add DeviceElementwiseBase and use it in device_normalize_instance.cpp

* Removing and renaming files

* Update to synchronize gemm_layernorm client example to the generic element-wise device op API

* Update to synchronize with the latest headers directory and HostTensorDescriptor interface renaming

* Merge two static member functions in device_elementwise.hpp

* Remove unary_elementwise_1d kernel and device
This commit is contained in:
Qianfeng
2022-08-15 23:11:02 +08:00
committed by GitHub
parent 5ee304595c
commit 53ea4713af
47 changed files with 5195 additions and 1707 deletions

View File

@@ -128,11 +128,14 @@ bool RunDeviceNormalize2D(normalize_op_ptr& p_op,
std::array<void*, 1> output = {p_y};
auto normalize_functor = ck::tensor_operation::element_wise::Normalize{};
auto argument_ptr = p_op->MakeArgumentPointer(input,
std::array<ck::index_t, 2> xyLengths = {M, N};
std::array<ck::index_t, 2> xyStrides = {StrideX, 1};
auto argument_ptr = p_op->MakeArgumentPointer(xyLengths,
{xyStrides, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
{xyStrides},
input,
output,
{M, N},
{{StrideX, 1}, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
{{StrideX, 1}},
ck::tensor_operation::element_wise::Normalize{});
if(p_op->IsSupportedArgument(argument_ptr.get()))