mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Batchnorm-forward and Batchnorm-infer Implemented using generic kernels (#320)
* Implement multiple-reduction in one kernel (kernels, device ops, examples) * Add generic elementwise kernel and device interface * Add generator for normal-distributed data initialization * Add host reference implementation of batchnorm-forward and batchnorm-infer * Add examples for implementing batchnorm-forward and batchnorm-infer using generic kernels * Remove un-needed includes in batchnorm example * Renaming generic_elementwise to elementwise in kernel and device classes/functions * Change in gemm_layernorm examples to use DeviceElementwise instead of Device5AryElementwise * Change in example 19_binary_elementwise to use DeviceElementwise instead of DeviceBinaryElementwise * Change in device_cgemm_4gemm_xdl_cshuffle.hpp to use kernel_elementwise instead of kernel_binary_elementwise * Add DeviceElementwiseBase and use it in device_normalize_instance.cpp * Removing and renaming files * Update to synchronize gemm_layernorm client example to the generic element-wise device op API * Update to synchronize with the latest headers directory and HostTensorDescriptor interface renaming * Merge two static member functions in device_elementwise.hpp * Remove unary_elementwise_1d kernel and device
This commit is contained in:
2
example/33_multiple_reduce/CMakeLists.txt
Normal file
2
example/33_multiple_reduce/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# Dual-reduce examples: compute two reductions (mean and mean-square) in one kernel.
add_example_executable(example_dual_reduce_threadwise dual_reduce_threadwise.cpp)
add_example_executable(example_dual_reduce_multiblock dual_reduce_multiblock.cpp)
|
||||
37
example/33_multiple_reduce/README.md
Normal file
37
example/33_multiple_reduce/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Instructions for ```example_dual_reduce```
|
||||
|
||||
## Run ```example_dual_reduce_multiblock```
|
||||
```bash
|
||||
# -D <xxx> : input 4-d tensor lengths
|
||||
# -v <x> : verification (0=no, 1=yes)
|
||||
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
|
||||
#arg2: time kernel (0=no, 1=yes)
|
||||
./bin/example_dual_reduce_multiblock -D 600,28,28,256 -v 1 2 1
|
||||
```
|
||||
|
||||
Result
|
||||
```
|
||||
./bin/example_dual_reduce_multiblock -D 600,28,28,256 -v 1 2 1
|
||||
launch_and_time_kernel: grid_dim {150, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 1.19529 ms, 201.499 GB/s, DeviceMultipleReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_1_InSrcVectorSize_1,OutDstVectorSize_1_1>
|
||||
```
|
||||
|
||||
## Run ```example_dual_reduce_threadwise```
|
||||
```bash
|
||||
# -D <xxx> : input 4-d tensor lengths
|
||||
# -v <x> : verification (0=no, 1=yes)
|
||||
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
|
||||
#arg2: time kernel (0=no, 1=yes)
|
||||
./bin/example_dual_reduce_threadwise -D 8000,4,4,4 -v 1 2 1
|
||||
```
|
||||
|
||||
Result
|
||||
```
|
||||
./bin/example_dual_reduce_threadwise -D 8000,4,4,4 -v 1 2 1
|
||||
launch_and_time_kernel: grid_dim {32, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 0.01512 ms, 71.9577 GB/s, DeviceMultipleReduceThreadwise<256,M_C256_S1,K_C1_S4,InSrcVectorDim_1_InSrcVectorSize_2,OutDstVectorSize_1_1>
|
||||
```
|
||||
313
example/33_multiple_reduce/dual_reduce_common.hpp
Normal file
313
example/33_multiple_reduce/dual_reduce_common.hpp
Normal file
@@ -0,0 +1,313 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include <getopt.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/data_type.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
|
||||
|
||||
// getopt_long() option table:
//   -D / --inLengths : comma-separated 4-d input tensor lengths (value required)
//   -v / --verify    : 1/0 — verify the device result against the host reference
//   -? / --help      : print the usage text
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
                                       {"verify", required_argument, nullptr, 'v'},
                                       {"help", no_argument, nullptr, '?'},
                                       {nullptr, 0, nullptr, 0}}; // terminator entry
|
||||
|
||||
class SimpleAppArgs
|
||||
{
|
||||
private:
|
||||
int option_index = 0;
|
||||
|
||||
public:
|
||||
std::vector<size_t> inLengths = {600, 28, 28, 256};
|
||||
size_t n, h, w, c;
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
|
||||
public:
|
||||
SimpleAppArgs()
|
||||
{
|
||||
n = inLengths[0];
|
||||
h = inLengths[1];
|
||||
w = inLengths[2];
|
||||
c = inLengths[3];
|
||||
};
|
||||
|
||||
void show_usage(const char* cmd)
|
||||
{
|
||||
std::cout << "Usage of " << cmd << std::endl;
|
||||
std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
|
||||
<< std::endl;
|
||||
std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
|
||||
"comparing with the host-based reduction"
|
||||
<< std::endl;
|
||||
std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
|
||||
"value, 3=decimal value)"
|
||||
<< std::endl;
|
||||
std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl;
|
||||
};
|
||||
|
||||
int processArgs(int argc, char* argv[])
|
||||
{
|
||||
using ck::host_common::getTypeValuesFromString;
|
||||
|
||||
int ch;
|
||||
|
||||
while(1)
|
||||
{
|
||||
ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index);
|
||||
if(ch == -1)
|
||||
break;
|
||||
switch(ch)
|
||||
{
|
||||
case 'D':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
inLengths = getTypeValuesFromString<size_t>(optarg);
|
||||
if(inLengths.size() != 4)
|
||||
throw std::runtime_error(
|
||||
"Invalid option format! The number of integers is incorrect!");
|
||||
|
||||
break;
|
||||
case 'v':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
do_verification = static_cast<bool>(std::atoi(optarg));
|
||||
break;
|
||||
case '?':
|
||||
if(std::string(long_options[option_index].name) == "help")
|
||||
{
|
||||
show_usage(argv[0]);
|
||||
return (-1);
|
||||
};
|
||||
break;
|
||||
default: show_usage(argv[0]); return (-1);
|
||||
};
|
||||
};
|
||||
|
||||
if(optind + 2 > argc)
|
||||
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
|
||||
|
||||
init_method = std::atoi(argv[optind++]);
|
||||
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
|
||||
|
||||
n = inLengths[0];
|
||||
h = inLengths[1];
|
||||
w = inLengths[2];
|
||||
c = inLengths[3];
|
||||
|
||||
return (0);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InDataType, typename OutDataType1, typename OutDataType2, typename AccDataType>
|
||||
static void mean_meansquare_host(const Tensor<InDataType>& in,
|
||||
Tensor<OutDataType1>& mean_ref,
|
||||
Tensor<OutDataType2>& meansquare_ref,
|
||||
size_t n,
|
||||
size_t h,
|
||||
size_t w,
|
||||
size_t c)
|
||||
|
||||
{
|
||||
auto thread_reduce_func = [&](auto iN) {
|
||||
AccDataType mean = ck::type_convert<AccDataType>(0.0f);
|
||||
AccDataType meansquare = ck::type_convert<AccDataType>(0.0f);
|
||||
|
||||
// compute mean, meanquare, variance, invVariance
|
||||
for(std::size_t iH = 0; iH < h; iH++)
|
||||
{
|
||||
for(std::size_t iW = 0; iW < w; iW++)
|
||||
{
|
||||
for(std::size_t iC = 0; iC < c; iC++)
|
||||
{
|
||||
AccDataType curr_value = ck::type_convert<AccDataType>(in(iN, iH, iW, iC));
|
||||
|
||||
mean += curr_value;
|
||||
meansquare += curr_value * curr_value;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
mean = mean / (h * w * c);
|
||||
meansquare = meansquare / (h * w * c);
|
||||
|
||||
mean_ref(iN) = ck::type_convert<OutDataType1>(mean);
|
||||
meansquare_ref(iN) = ck::type_convert<OutDataType2>(meansquare);
|
||||
};
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
std::size_t work_per_thread = (n + num_thread - 1) / num_thread;
|
||||
|
||||
std::vector<joinable_thread> threads(num_thread);
|
||||
|
||||
for(std::size_t it = 0; it < num_thread; it++)
|
||||
{
|
||||
std::size_t iN_begin = it * work_per_thread;
|
||||
std::size_t iN_end = std::min(static_cast<size_t>((it + 1) * work_per_thread), n);
|
||||
|
||||
auto f = [=] {
|
||||
for(std::size_t iN = iN_begin; iN < iN_end; iN++)
|
||||
{
|
||||
thread_reduce_func(iN);
|
||||
}
|
||||
};
|
||||
|
||||
threads[it] = joinable_thread(f);
|
||||
}
|
||||
};
|
||||
|
||||
// Both reductions accumulate with addition; they differ only in the
// elementwise operations applied to the input and to the accumulated value.
using ReduceOperation = ck::reduce::Add;

// mean: sum the raw input values, then divide by the reduced element count
using InElementwiseOperation_Mean  = ck::tensor_operation::element_wise::PassThrough;
using AccElementwiseOperation_Mean = ck::tensor_operation::element_wise::UnaryDivide;

// mean-square: square each input value before summing, then divide
using InElementwiseOperation_Meansquare  = ck::tensor_operation::element_wise::UnarySquare;
using AccElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnaryDivide;

// per-output operation tuples consumed by the dual-reduce device ops
using InElementwiseOperationTuple =
    ck::Tuple<InElementwiseOperation_Mean, InElementwiseOperation_Meansquare>;
using AccElementwiseOperationTuple =
    ck::Tuple<AccElementwiseOperation_Mean, AccElementwiseOperation_Meansquare>;
|
||||
|
||||
// Runs one dual reduction (per-image mean and mean-square of an NHWC tensor,
// reduced over the dimensions listed in reduceDims) on the device, reports the
// measured bandwidth, and optionally verifies both outputs against the
// multi-threaded host reference.
//
// Template parameters:
//   DeviceDualReduce        - device op instance that produces two outputs
//   InDataType/OutDataType  - input / output element types
//   AccDataType             - accumulator type
//   Rank                    - rank of the input tensor (4 for NHWC)
//   NumReduceDim            - number of reduced dimensions
//
// Returns 0 on success (verification passed or skipped), -1 when the runtime
// configuration is not supported by the device instance, 1 on mismatch.
template <typename DeviceDualReduce,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          int Rank,
          int NumReduceDim>
int mean_meansquare_dual_reduce_test(size_t n,
                                     size_t h,
                                     size_t w,
                                     size_t c,
                                     bool do_verification,
                                     int init_method,
                                     bool time_kernel,
                                     const std::array<int, NumReduceDim> reduceDims)

{
    const std::vector<size_t> inLengths = {n, h, w, c};

    Tensor<InDataType> in(inLengths);

    // each output holds one value per image (per N index)
    std::vector<size_t> outLengths{n};

    Tensor<OutDataType> mean_ref(outLengths);
    Tensor<OutDataType> mean(outLengths);
    Tensor<OutDataType> meansquare_ref(outLengths);
    Tensor<OutDataType> meansquare(outLengths);

    auto inStrides  = in.mDesc.GetStrides();
    auto outStrides = mean.mDesc.GetStrides();

    size_t invariant_total_length = n;
    size_t reduce_total_length    = h * w * c;

    // output scaling factors passed to the device op
    // (alpha=1, beta=0 — presumably out = alpha*reduce(in) + beta*out; confirm
    // against the device op's MakeArgumentPointer documentation)
    const AccDataType alpha = ck::type_convert<AccDataType>(1.0f);
    const AccDataType beta  = ck::type_convert<AccDataType>(0.0f);

    std::size_t num_thread = 1;

    // input is only initialized when verification will read it back
    if(do_verification)
    {
        switch(init_method)
        {
        case 0: break;
        case 1: in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread); break;
        case 2: in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread); break;
        default: in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
        }
    };

    // these buffers are usually provided by the user application
    DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
    DeviceMem mean_dev(sizeof(OutDataType) * mean.mDesc.GetElementSpaceSize());
    DeviceMem meansquare_dev(sizeof(OutDataType) * meansquare.mDesc.GetElementSpaceSize());

    in_dev.ToDevice(in.mData.data());

    // compute the host reference before launching the device kernel
    if(do_verification)
    {
        mean_meansquare_host<InDataType, OutDataType, OutDataType, AccDataType>(
            in, mean_ref, meansquare_ref, n, h, w, c);
    };

    constexpr ck::index_t NumInputDim  = Rank;
    constexpr ck::index_t NumOutputDim = (Rank - NumReduceDim > 1) ? Rank - NumReduceDim : 1;

    // the device op takes fixed-size ck::index_t arrays; convert the size_t vectors
    std::array<ck::index_t, NumInputDim> i_inLengths;
    std::array<ck::index_t, NumInputDim> i_inStrides;
    std::array<ck::index_t, NumOutputDim> i_outLengths;
    std::array<ck::index_t, NumOutputDim> i_outStrides;

    std::copy(inLengths.begin(), inLengths.end(), i_inLengths.begin());
    std::copy(inStrides.begin(), inStrides.end(), i_inStrides.begin());
    std::copy(outLengths.begin(), outLengths.end(), i_outLengths.begin());
    std::copy(outStrides.begin(), outStrides.end(), i_outStrides.begin());

    auto dual_reduce_op = DeviceDualReduce{};

    // one entry per output: {mean, meansquare} strides, buffers and
    // elementwise operations; UnaryDivide is constructed with the reduced
    // element count so the accumulated sums become averages
    auto argument_ptr = dual_reduce_op.MakeArgumentPointer(
        i_inLengths,
        i_inStrides,
        i_outLengths,
        {i_outStrides, i_outStrides},
        reduceDims,
        {&alpha, &alpha},
        {&beta, &beta},
        in_dev.GetDeviceBuffer(),
        {mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()},
        ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
        ck::make_tuple(
            AccElementwiseOperation_Mean{static_cast<int32_t>(reduce_total_length)},
            AccElementwiseOperation_Meansquare{static_cast<int32_t>(reduce_total_length)}));

    if(!dual_reduce_op.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout
            << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
            << std::endl;
        return (-1);
    };

    std::string reduce_name = dual_reduce_op.GetTypeString();

    auto invoker_ptr = dual_reduce_op.MakeInvokerPointer();

    float avg_time = 0.0f;

    avg_time += invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    // bytes moved: the whole input read once plus the two outputs written once
    std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) +
                            2 * invariant_total_length * sizeof(OutDataType);

    float gb_per_sec = num_bytes / 1.E6 / avg_time;

    std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name
              << std::endl;

    bool pass = true;

    if(do_verification)
    {
        mean_dev.FromDevice(mean.mData.data());
        meansquare_dev.FromDevice(meansquare.mData.data());
        pass = pass && ck::utils::check_err(mean.mData, mean_ref.mData);
        pass = pass && ck::utils::check_err(meansquare.mData, meansquare_ref.mData);
    };

    return (pass ? 0 : 1);
}
|
||||
98
example/33_multiple_reduce/dual_reduce_multiblock.cpp
Normal file
98
example/33_multiple_reduce/dual_reduce_multiblock.cpp
Normal file
@@ -0,0 +1,98 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <cstdlib>
#include <vector>
#include <array>
#include <algorithm>
#include <getopt.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/data_type.hpp"

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"

#include "dual_reduce_common.hpp"

using namespace ck;
using namespace ck::tensor_operation::device;

// half-precision input; both outputs (mean, mean-square) are float
using InDataType       = ck::half_t;
using OutDataType      = float;
using OutDataTypeTuple = Tuple<OutDataType, OutDataType>;
using AccDataType      = float;

// for NHWC layer-norm calculation of mean and meansquare
constexpr int Rank         = 4;
constexpr int NumReduceDim = 3;

constexpr bool PropagateNan = false;

// results are written with a plain store (Set), not atomically accumulated
constexpr InMemoryDataOperationEnum OutMemoryDataOperation = InMemoryDataOperationEnum::Set;

// Multi-block dual-reduce instance producing 2 outputs in one pass.
// Judging by the instance's GetTypeString() shown in the README
// (256,M_C4_S1,K_C64_S1,...), the numeric parameters are: 256 = block size,
// 4/64 = M/K thread-cluster sizes, 1/1 = M/K thread-slice sizes, then the
// input vector dim/size and per-output destination vector sizes
// -- TODO confirm against device_multiple_reduce_multiblock.hpp
using DeviceDualReduce = DeviceMultipleReduceMultiBlock<2,
                                                        InDataType,
                                                        AccDataType,
                                                        OutDataTypeTuple,
                                                        Rank,
                                                        NumReduceDim,
                                                        ReduceOperation,
                                                        InElementwiseOperationTuple,
                                                        AccElementwiseOperationTuple,
                                                        OutMemoryDataOperation,
                                                        PropagateNan,
                                                        256,
                                                        4,
                                                        64,
                                                        1,
                                                        1,
                                                        1, // InSrcVectorDim
                                                        1,
                                                        ck::Sequence<1, 1>>;
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
int retval = 0;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
SimpleAppArgs arg;
|
||||
|
||||
if(arg.processArgs(argc, argv) < 0)
|
||||
return (-1);
|
||||
|
||||
std::array<int, NumReduceDim> reduceDims = {1, 2, 3};
|
||||
|
||||
retval = mean_meansquare_dual_reduce_test<DeviceDualReduce,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
Rank,
|
||||
NumReduceDim>(arg.n,
|
||||
arg.h,
|
||||
arg.w,
|
||||
arg.c,
|
||||
arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
reduceDims);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::array<int, NumReduceDim> reduceDims = {1, 2, 3};
|
||||
|
||||
retval = mean_meansquare_dual_reduce_test<DeviceDualReduce,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
Rank,
|
||||
NumReduceDim>(
|
||||
600, 28, 28, 256, true, 2, true, reduceDims);
|
||||
};
|
||||
|
||||
return (retval);
|
||||
}
|
||||
93
example/33_multiple_reduce/dual_reduce_threadwise.cpp
Normal file
93
example/33_multiple_reduce/dual_reduce_threadwise.cpp
Normal file
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <cstdlib>
#include <vector>
#include <array>
#include <algorithm>
#include <getopt.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/data_type.hpp"

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce_threadwise.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"

#include "dual_reduce_common.hpp"

using namespace ck;
using namespace ck::tensor_operation::device;

// half-precision input; both outputs (mean, mean-square) are float
using InDataType       = ck::half_t;
using OutDataType      = float;
using OutDataTypeTuple = Tuple<OutDataType, OutDataType>;
using AccDataType      = float;

// for NHWC layer-norm calculation of mean and meansquare
constexpr int Rank         = 4;
constexpr int NumReduceDim = 3;

constexpr bool PropagateNan = false;

// Thread-wise dual-reduce instance (each thread reduces its own slice),
// suited to small reduction lengths. Judging by the instance's GetTypeString()
// shown in the README (256,M_C256_S1,K_C1_S4,...), the numeric parameters are:
// 256 = block size, 1/4 = M/K thread-slice sizes, then input vector dim = 1,
// input vector size = 2, and per-output destination vector sizes
// -- TODO confirm against device_multiple_reduce_threadwise.hpp
using DeviceDualReduce = DeviceMultipleReduceThreadWise<2,
                                                        InDataType,
                                                        AccDataType,
                                                        OutDataTypeTuple,
                                                        Rank,
                                                        NumReduceDim,
                                                        ReduceOperation,
                                                        InElementwiseOperationTuple,
                                                        AccElementwiseOperationTuple,
                                                        PropagateNan,
                                                        256,
                                                        1,
                                                        4,
                                                        1, // InSrcVectorDim
                                                        2,
                                                        ck::Sequence<1, 1>>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
int retval = 0;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
SimpleAppArgs arg;
|
||||
|
||||
if(arg.processArgs(argc, argv) < 0)
|
||||
return (-1);
|
||||
|
||||
std::array<int, NumReduceDim> reduceDims = {1, 2, 3};
|
||||
|
||||
retval = mean_meansquare_dual_reduce_test<DeviceDualReduce,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
Rank,
|
||||
NumReduceDim>(arg.n,
|
||||
arg.h,
|
||||
arg.w,
|
||||
arg.c,
|
||||
arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
reduceDims);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::array<int, NumReduceDim> reduceDims = {1, 2, 3};
|
||||
|
||||
retval = mean_meansquare_dual_reduce_test<DeviceDualReduce,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
Rank,
|
||||
NumReduceDim>(
|
||||
8000, 4, 4, 4, true, 2, true, reduceDims);
|
||||
};
|
||||
|
||||
return (retval);
|
||||
}
|
||||
Reference in New Issue
Block a user