mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Batchnorm-forward and Batchnorm-infer Implemented using generic kernels (#320)
* Implement multiple-reduction in one kernel (kernels, device ops, examples) * Add generic elementwise kernel and device interface * Add generator for normal-distributed data initialization * Add host refer implementation of batchnorm-forward and batchnorm-infer * Add examples for implementing batchnorm-forward and batchnorm-infer using generic kernels * Remove un-needed including in batchnorm example * Renaming generic_elementwise to elementiwise in kernel and device classes/functions * Change in gemm_layernorm examples to use DeviceElementwise instead of Device5AryElementwise * Change in exampe 19_binary_elementwise to use DeviceElementwise instead of DeviceBinaryElementwise * Change in device_cgemm_4gemm_xdl_cshuffle.hpp to use kernel_elementwise instead of kernel_binary_elementwise * Add DeviceElementwiseBase and use it in device_normalize_instance.cpp * Removing and renaming files * Update to synchronize gemm_layernorm client example to the generic element-wise device op API * Update to synchronize with the latest headers directory and HostTensorDescriptor interface renaming * Merge two static member functions in device_elementwise.hpp * Remove unary_elementwise_1d kernel and device
This commit is contained in:
@@ -128,11 +128,14 @@ bool RunDeviceNormalize2D(normalize_op_ptr& p_op,
|
||||
std::array<void*, 1> output = {p_y};
|
||||
auto normalize_functor = ck::tensor_operation::element_wise::Normalize{};
|
||||
|
||||
auto argument_ptr = p_op->MakeArgumentPointer(input,
|
||||
std::array<ck::index_t, 2> xyLengths = {M, N};
|
||||
std::array<ck::index_t, 2> xyStrides = {StrideX, 1};
|
||||
|
||||
auto argument_ptr = p_op->MakeArgumentPointer(xyLengths,
|
||||
{xyStrides, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
|
||||
{xyStrides},
|
||||
input,
|
||||
output,
|
||||
{M, N},
|
||||
{{StrideX, 1}, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
|
||||
{{StrideX, 1}},
|
||||
ck::tensor_operation::element_wise::Normalize{});
|
||||
|
||||
if(p_op->IsSupportedArgument(argument_ptr.get()))
|
||||
|
||||
Reference in New Issue
Block a user