mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Pr82 followup (#115)
* Use thread cluster descriptor and explicit M_K 2d descriptor to simply Blockwise Reduction * Change by replacing ReduceDims by NumReduceDims as Device Reduce interface template parameter * Rename the folder name for the pool2d and reduce examples * Update to reduction test scripts * Add Readme for pool2d_fwd and reduce_blockwise examples * Tiny fix in reduce profiler and tiny update in reduce testing scripts * Tiny fix in testing script profile_reduce_no_index.sh * Tiny change in script/profile_reduce_with_index.sh * Renaming and refining in Reduction profiler/device layer/examples * Renaming and refining in Reduction profiler/device layer/examples * Renaming all NumReduceDims to NumReduceDim
This commit is contained in:
60
example/12_reduce/README.md
Normal file
60
example/12_reduce/README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# Instructions for ```reduce_blockwise``` Example
|
||||
|
||||
## Docker script
|
||||
```bash
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
## Build ```reduce_blockwise```
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
```
|
||||
|
||||
```bash
|
||||
# Need to specify target ID, example below is gfx908
|
||||
cmake \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
..
|
||||
```
|
||||
|
||||
```bash
|
||||
make -j reduce_blockwise
|
||||
```
|
||||
|
||||
## Run ```reduce_blockwise```
|
||||
```bash
|
||||
# -D <xxx> : input 4-d tensor lengths
|
||||
# -v <x> : verification (0=no, 1=yes)
|
||||
#arg1: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg2: run kernel # of times (>1)
|
||||
./bin/reduce_blockwise -D 16,64,32,960 -v 1 1 10
|
||||
```
|
||||
|
||||
Result
|
||||
```
|
||||
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 3 times...
|
||||
Perf: 0.23536 ms, 267.32 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
|
||||
error: 0
|
||||
max_diff: 0, 529, 529
|
||||
root@dc-smc-18:/data/composable_kernel/Build3# bin/reduce_blockwise -D 16,64,32,960 -v 1 1 10
|
||||
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 10 times...
|
||||
Perf: 0.23392 ms, 268.966 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
|
||||
error: 0
|
||||
max_diff: 0, 528, 528
|
||||
```
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "device_reduce_blockwise.hpp"
|
||||
#include "host_reduce_util.hpp"
|
||||
#include "host_generic_reduction.hpp"
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
|
||||
@@ -28,8 +29,8 @@ using kInDataType = ck::half_t;
|
||||
using kOutDataType = ck::half_t;
|
||||
using kAccDataType = float;
|
||||
|
||||
constexpr int Rank = 4;
|
||||
using ReduceDims_ = ck::Sequence<0, 1, 2>;
|
||||
constexpr int Rank = 4;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
constexpr ReduceTensorOp_t ReduceOpId = ReduceTensorOp_t::NORM2;
|
||||
constexpr NanPropagation_t NanOpt = NanPropagation_t::PROPAGATE_NAN;
|
||||
@@ -46,7 +47,7 @@ using DeviceReduceInstance = DeviceReduceBlockWise<kInDataType,
|
||||
kAccDataType,
|
||||
kOutDataType,
|
||||
Rank,
|
||||
ReduceDims_,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
@@ -192,39 +193,13 @@ class SimpleAppArgs
|
||||
};
|
||||
};
|
||||
|
||||
template <int Rank, typename ReduceDims>
|
||||
static std::vector<int> get_reduce_dims()
|
||||
{
|
||||
std::vector<int> resDims;
|
||||
|
||||
static_for<0, ReduceDims::Size(), 1>{}([&](auto i) { resDims.push_back(ReduceDims::At(i)); });
|
||||
|
||||
return (resDims);
|
||||
};
|
||||
|
||||
template <int Rank, typename ReduceDims>
|
||||
static std::vector<int> get_invariant_dims()
|
||||
{
|
||||
std::vector<int> resDims;
|
||||
unsigned int incFlag = 0;
|
||||
|
||||
static_for<0, ReduceDims::Size(), 1>{}(
|
||||
[&](auto i) { incFlag = incFlag | (0x1 << ReduceDims::At(i)); });
|
||||
|
||||
for(int dim = 0; dim < Rank; dim++)
|
||||
{
|
||||
if(incFlag & (0x1 << dim))
|
||||
continue;
|
||||
resDims.push_back(dim);
|
||||
};
|
||||
|
||||
return (resDims);
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
const std::vector<int> reduceDims{0, 1, 2};
|
||||
const std::vector<int> invariantDims{3};
|
||||
|
||||
SimpleAppArgs args;
|
||||
|
||||
if(args.processArgs(argc, argv) < 0)
|
||||
@@ -260,15 +235,12 @@ int main(int argc, char* argv[])
|
||||
|
||||
Tensor<InDataType> in(args.inLengths);
|
||||
|
||||
const std::vector<int> InvariantDims = get_invariant_dims<Rank, ReduceDims_>();
|
||||
const std::vector<int> ReduceDims = get_reduce_dims<Rank, ReduceDims_>();
|
||||
|
||||
std::vector<size_t> outLengths;
|
||||
|
||||
if(InvariantDims.empty())
|
||||
if(invariantDims.empty())
|
||||
outLengths.push_back(1);
|
||||
else
|
||||
for(auto dim : InvariantDims)
|
||||
for(auto dim : invariantDims)
|
||||
outLengths.push_back(args.inLengths[dim]);
|
||||
|
||||
Tensor<OutDataType> out_ref(outLengths);
|
||||
@@ -328,7 +300,7 @@ int main(int argc, char* argv[])
|
||||
if(args.do_verification)
|
||||
{
|
||||
ReductionHost<InDataType, AccDataType, OutDataType, ReduceOpId, PropagateNan, NeedIndices>
|
||||
hostReduce(in.mDesc, out_ref.mDesc, InvariantDims, ReduceDims);
|
||||
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(
|
||||
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
|
||||
@@ -350,6 +322,7 @@ int main(int argc, char* argv[])
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
|
||||
55
example/13_pool2d_fwd/README.md
Normal file
55
example/13_pool2d_fwd/README.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Instructions for ```pool2d_fwd``` Example
|
||||
|
||||
## Docker script
|
||||
```bash
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
## Build ```pool2d_fwd```
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
```
|
||||
|
||||
```bash
|
||||
# Need to specify target ID, example below is gfx908
|
||||
cmake \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
..
|
||||
```
|
||||
|
||||
```bash
|
||||
make -j pool2d_fwd
|
||||
```
|
||||
|
||||
## Run ```pool2d_fwd```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg3: run kernel # of times (>1)
|
||||
#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx
|
||||
./example/pool2d_fwd 1 1 10
|
||||
```
|
||||
|
||||
Result
|
||||
```
|
||||
in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
|
||||
out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192}
|
||||
launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1}
|
||||
Warm up
|
||||
Start running 10 times...
|
||||
Perf: 0.415453 ms, 1.37996 TFlops, 749.726 GB/s
|
||||
error: 0
|
||||
max_diff: 0, 1, 1
|
||||
```
|
||||
@@ -32,57 +32,53 @@
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
|
||||
#include "cluster_descriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename Buffer1dDescType,
|
||||
typename AccDataType,
|
||||
template <typename AccDataType,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
bool ReorderThreadClusters,
|
||||
typename ThreadClusterLengths_M_K,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
struct PartitionedBlockwiseReductionOn1dBuffer
|
||||
struct PartitionedBlockwiseReduction
|
||||
{
|
||||
static constexpr auto buffer_1d_desc = Buffer1dDescType{};
|
||||
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
|
||||
"The product of cluster lengths should be same as BlockSize!");
|
||||
static_assert(KThreadClusterSize > 1, "Parallel reduction need work on at least two elements");
|
||||
|
||||
static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
|
||||
"The buffer size should be the same as BlockSize!");
|
||||
static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
|
||||
static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
|
||||
|
||||
static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
|
||||
|
||||
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
|
||||
|
||||
template <typename BufferType>
|
||||
__device__ static void Reduce(BufferType& block_buffer,
|
||||
AccDataType& accuData,
|
||||
index_t thread_m_cluster_id,
|
||||
index_t thread_k_cluster_id)
|
||||
__device__ static void Reduce(BufferType& block_buffer, AccDataType& accuData)
|
||||
{
|
||||
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();
|
||||
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
|
||||
|
||||
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
|
||||
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
|
||||
|
||||
if(thread_k_cluster_id < indOffset)
|
||||
{
|
||||
// consider the thread clusters order, ensure the contiguous locations are accessed
|
||||
// by contiguous Thread-ID
|
||||
index_t offset1 =
|
||||
ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
|
||||
index_t offset2 = ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
|
||||
thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(
|
||||
make_tuple(thread_m_cluster_id * KThreadClusterSize +
|
||||
(thread_k_cluster_id + indOffset)));
|
||||
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
|
||||
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
|
||||
make_tuple(0, indOffset));
|
||||
|
||||
AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
|
||||
AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
|
||||
@@ -93,34 +89,34 @@ struct PartitionedBlockwiseReductionOn1dBuffer
|
||||
__syncthreads();
|
||||
});
|
||||
|
||||
index_t offset = ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(
|
||||
make_tuple(thread_m_cluster_id * KThreadClusterSize));
|
||||
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
|
||||
|
||||
accuData = type_convert<AccDataType>(block_buffer[offset]);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Buffer1dDescType,
|
||||
typename AccDataType,
|
||||
template <typename AccDataType,
|
||||
typename IndexDataType,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
bool ReorderThreadClusters,
|
||||
typename ThreadClusterLengths_M_K,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
|
||||
struct PartitionedBlockwiseReductionWithIndex
|
||||
{
|
||||
static constexpr auto buffer_1d_desc = Buffer1dDescType{};
|
||||
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
|
||||
"The product of cluster lengths should be same as BlockSize!");
|
||||
static_assert(KThreadClusterSize > 1, "Parallel reduction need work on at least two elements");
|
||||
|
||||
static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
|
||||
"The buffer size should be the same as BlockSize!");
|
||||
static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
|
||||
static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
|
||||
|
||||
static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
|
||||
|
||||
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using Accumulation =
|
||||
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
|
||||
@@ -130,32 +126,24 @@ struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
|
||||
__device__ static void Reduce(BufferType& block_val_buffer,
|
||||
IdxBufferType& block_idx_buffer,
|
||||
AccDataType& accuData,
|
||||
IndexDataType& accuIndex,
|
||||
index_t thread_m_cluster_id,
|
||||
index_t thread_k_cluster_id)
|
||||
IndexDataType& accuIndex)
|
||||
{
|
||||
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();
|
||||
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
|
||||
|
||||
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
|
||||
constexpr index_t indOffset = 1 << I();
|
||||
|
||||
if(thread_k_cluster_id % (indOffset * 2) == 0)
|
||||
{
|
||||
// consider the thread clusters order, ensure the contiguous locations are accessed
|
||||
// by contiguous Thread-ID
|
||||
index_t offset1 =
|
||||
ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
|
||||
index_t offset2 = ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
|
||||
thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(
|
||||
make_tuple(thread_m_cluster_id * KThreadClusterSize +
|
||||
(thread_k_cluster_id + indOffset)));
|
||||
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
|
||||
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
|
||||
make_tuple(0, indOffset));
|
||||
|
||||
AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]);
|
||||
AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]);
|
||||
@@ -170,10 +158,7 @@ struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
|
||||
__syncthreads();
|
||||
});
|
||||
|
||||
index_t offset = ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(
|
||||
make_tuple(thread_m_cluster_id * KThreadClusterSize));
|
||||
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
|
||||
|
||||
accuData = type_convert<AccDataType>(block_val_buffer[offset]);
|
||||
accuIndex = block_idx_buffer[offset];
|
||||
|
||||
@@ -36,14 +36,15 @@ struct DeviceReduce : public BaseOperator
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
const InElementwiseOperation& inElementwiseOp,
|
||||
const AccElementwiseOperation& accElementwiseOp) = 0;
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
@@ -15,8 +15,8 @@ namespace device {
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
@@ -40,7 +40,12 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
|
||||
|
||||
static constexpr bool BetaIsZero = NeedIndices;
|
||||
|
||||
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
using InvariantDims =
|
||||
typename conditional<NumInvariantDim == 0,
|
||||
Sequence<>,
|
||||
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
static constexpr index_t srcDims = Rank;
|
||||
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
|
||||
@@ -74,7 +79,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto toReduceDimLengths =
|
||||
const auto reduceDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
|
||||
@@ -82,7 +87,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
@@ -136,6 +141,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
@@ -144,30 +150,31 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op)
|
||||
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
|
||||
: outLengths_{outLengths},
|
||||
outStrides_{outStrides},
|
||||
in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
out_indices_dev_{out_indices_dev},
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
{
|
||||
(void)workspace_dev;
|
||||
|
||||
inLengths_ = inLengths;
|
||||
inStrides_ = inStrides;
|
||||
outLengths_ = outLengths;
|
||||
outStrides_ = outStrides;
|
||||
|
||||
in_elementwise_op_ = in_elementwise_op;
|
||||
acc_elementwise_op_ = acc_elementwise_op;
|
||||
std::tie(inLengths_, inStrides_) =
|
||||
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
|
||||
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<OutDataType>(beta);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths);
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths_);
|
||||
|
||||
if constexpr(InvariantDims::Size() == 0)
|
||||
invariant_lowest_length = 1;
|
||||
else
|
||||
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
|
||||
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize;
|
||||
@@ -305,6 +312,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
@@ -318,6 +326,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
|
||||
@@ -15,8 +15,8 @@ namespace device {
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
@@ -45,7 +45,11 @@ struct DeviceReduceBlockWiseSecondCall
|
||||
std::is_same<InDataType, AccDataType>::value,
|
||||
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
|
||||
|
||||
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
using InvariantDims =
|
||||
typename conditional<NumInvariantDim == 0,
|
||||
Sequence<>,
|
||||
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
|
||||
|
||||
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
|
||||
|
||||
@@ -117,16 +121,16 @@ struct DeviceReduceBlockWiseSecondCall
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op)
|
||||
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
|
||||
: inLengths_(inLengths),
|
||||
inStrides_(inStrides),
|
||||
outLengths_(outLengths),
|
||||
outStrides_(outStrides),
|
||||
in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
out_indices_dev_{out_indices_dev},
|
||||
in_elementwise_op_(in_elementwise_op),
|
||||
acc_elementwise_op_(acc_elementwise_op)
|
||||
{
|
||||
inLengths_ = inLengths;
|
||||
inStrides_ = inStrides;
|
||||
outLengths_ = outLengths;
|
||||
outStrides_ = outStrides;
|
||||
|
||||
in_elementwise_op_ = in_elementwise_op;
|
||||
acc_elementwise_op_ = acc_elementwise_op;
|
||||
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<OutDataType>(beta);
|
||||
|
||||
@@ -268,6 +272,7 @@ struct DeviceReduceBlockWiseSecondCall
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
@@ -277,6 +282,8 @@ struct DeviceReduceBlockWiseSecondCall
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op) override
|
||||
{
|
||||
(void)reduceDims;
|
||||
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
outLengths,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define DEVICE_REDUCE_COMMON_HPP
|
||||
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "reduction_enums.hpp"
|
||||
@@ -40,23 +41,6 @@ constexpr bool belong()
|
||||
return (inside);
|
||||
};
|
||||
|
||||
template <int Rank, typename ReduceDims, int start = 0>
|
||||
constexpr auto get_invariant_dims()
|
||||
{
|
||||
static_assert(Rank <= 6, "bigger Rank size not supported!");
|
||||
|
||||
if constexpr(start >= Rank)
|
||||
return Sequence<>{};
|
||||
else
|
||||
{
|
||||
if constexpr(!belong<start, ReduceDims>())
|
||||
return merge_sequences(Sequence<start>{},
|
||||
get_invariant_dims<Rank, ReduceDims, start + 1>());
|
||||
else
|
||||
return get_invariant_dims<Rank, ReduceDims, start + 1>();
|
||||
};
|
||||
};
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
static auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
|
||||
@@ -74,6 +58,45 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t Rank, index_t NumReduceDim>
|
||||
static inline std::pair<std::vector<int>, std::vector<int>>
|
||||
shuffle_tensor_dimensions(const std::vector<int>& dimLengths,
|
||||
const std::vector<int>& dimStrides,
|
||||
const std::vector<int>& reduceDims)
|
||||
{
|
||||
std::vector<int> newDimLengths;
|
||||
std::vector<int> newDimStrides;
|
||||
|
||||
assert(Rank == dimLengths.size() && Rank == dimStrides.size() &&
|
||||
NumReduceDim == reduceDims.size());
|
||||
|
||||
int reduceFlag = 0;
|
||||
|
||||
// flag the bits for the reduceDims
|
||||
for(int i = 0; i < NumReduceDim; i++)
|
||||
{
|
||||
reduceFlag |= 1 << reduceDims[i];
|
||||
};
|
||||
|
||||
// collect invariant dimensions
|
||||
for(int i = 0; i < Rank; i++)
|
||||
if((reduceFlag & (1 << i)) == 0)
|
||||
{
|
||||
newDimLengths.push_back(dimLengths[i]);
|
||||
newDimStrides.push_back(dimStrides[i]);
|
||||
};
|
||||
|
||||
// collect reduce dimensions
|
||||
for(int i = 0; i < Rank; i++)
|
||||
if((reduceFlag & (1 << i)) > 0)
|
||||
{
|
||||
newDimLengths.push_back(dimLengths[i]);
|
||||
newDimStrides.push_back(dimStrides[i]);
|
||||
};
|
||||
|
||||
return std::make_pair(newDimLengths, newDimStrides);
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ namespace device {
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
@@ -41,7 +41,12 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
using InvariantDims =
|
||||
typename conditional<NumInvariantDim == 0,
|
||||
Sequence<>,
|
||||
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
static constexpr index_t srcDims = Rank;
|
||||
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
|
||||
@@ -84,7 +89,7 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto toReduceDimLengths =
|
||||
const auto reduceDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
|
||||
@@ -92,7 +97,7 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
@@ -147,6 +152,7 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
@@ -155,31 +161,31 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op)
|
||||
: in_dev_{in_dev}, out_dev_{out_dev}
|
||||
: outLengths_{outLengths},
|
||||
outStrides_{outStrides},
|
||||
in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
{
|
||||
(void)out_indices_dev;
|
||||
(void)workspace_dev;
|
||||
|
||||
inLengths_ = inLengths;
|
||||
inStrides_ = inStrides;
|
||||
outLengths_ = outLengths;
|
||||
outStrides_ = outStrides;
|
||||
|
||||
in_elementwise_op_ = in_elementwise_op;
|
||||
acc_elementwise_op_ = acc_elementwise_op;
|
||||
std::tie(inLengths_, inStrides_) =
|
||||
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
|
||||
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<OutDataType>(beta);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths);
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths_);
|
||||
|
||||
if constexpr(InvariantDims::Size() == 0)
|
||||
invariant_lowest_length = 1;
|
||||
else
|
||||
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
|
||||
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
@@ -369,6 +375,7 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
@@ -382,6 +389,7 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
|
||||
@@ -15,8 +15,8 @@ namespace device {
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
@@ -41,7 +41,12 @@ struct DeviceReduceMultiBlockPartialReduce
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
using InvariantDims =
|
||||
typename conditional<NumInvariantDim == 0,
|
||||
Sequence<>,
|
||||
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
static constexpr index_t srcDims = Rank;
|
||||
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
|
||||
@@ -112,7 +117,7 @@ struct DeviceReduceMultiBlockPartialReduce
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto toReduceDimLengths =
|
||||
const auto reduceDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
|
||||
@@ -120,7 +125,7 @@ struct DeviceReduceMultiBlockPartialReduce
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
@@ -161,10 +166,11 @@ struct DeviceReduceMultiBlockPartialReduce
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<index_t>& inLengths,
|
||||
const std::vector<index_t>& inStrides,
|
||||
const std::vector<index_t>& outLengths,
|
||||
const std::vector<index_t>& outStrides,
|
||||
Argument(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
@@ -173,31 +179,30 @@ struct DeviceReduceMultiBlockPartialReduce
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op)
|
||||
: in_dev_{in_dev},
|
||||
: outLengths_{outLengths},
|
||||
outStrides_{outStrides},
|
||||
in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
out_indices_dev_{out_indices_dev},
|
||||
workspace_dev_{workspace_dev}
|
||||
workspace_dev_{workspace_dev},
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
{
|
||||
inLengths_ = inLengths;
|
||||
inStrides_ = inStrides;
|
||||
outLengths_ = outLengths;
|
||||
outStrides_ = outStrides;
|
||||
|
||||
in_elementwise_op_ = in_elementwise_op;
|
||||
acc_elementwise_op_ = acc_elementwise_op;
|
||||
std::tie(inLengths_, inStrides_) =
|
||||
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
|
||||
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<OutDataType>(beta);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths);
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths_);
|
||||
|
||||
if constexpr(InvariantDims::Size() == 0)
|
||||
invariant_lowest_length = 1;
|
||||
else
|
||||
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
|
||||
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
@@ -370,6 +375,7 @@ struct DeviceReduceMultiBlockPartialReduce
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
@@ -383,6 +389,7 @@ struct DeviceReduceMultiBlockPartialReduce
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
|
||||
@@ -16,7 +16,7 @@ template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
typename ReduceDims,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
@@ -40,7 +40,12 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
|
||||
static constexpr bool BetaIsZero = NeedIndices;
|
||||
|
||||
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
using InvariantDims =
|
||||
typename conditional<NumInvariantDim == 0,
|
||||
Sequence<>,
|
||||
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
static constexpr index_t srcDims = Rank;
|
||||
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
|
||||
@@ -74,7 +79,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto toReduceDimLengths =
|
||||
const auto reduceDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
|
||||
@@ -82,7 +87,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
@@ -136,6 +141,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
@@ -144,30 +150,32 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const OutElementwiseOperation& acc_elementwise_op)
|
||||
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
|
||||
: outLengths_{outLengths},
|
||||
outStrides_{outStrides},
|
||||
in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
out_indices_dev_{out_indices_dev},
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
|
||||
{
|
||||
(void)workspace_dev;
|
||||
|
||||
inLengths_ = inLengths;
|
||||
inStrides_ = inStrides;
|
||||
outLengths_ = outLengths;
|
||||
outStrides_ = outStrides;
|
||||
|
||||
in_elementwise_op_ = in_elementwise_op;
|
||||
acc_elementwise_op_ = acc_elementwise_op;
|
||||
std::tie(inLengths_, inStrides_) =
|
||||
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
|
||||
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<OutDataType>(beta);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths);
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths_);
|
||||
|
||||
if constexpr(InvariantDims::Size() == 0)
|
||||
invariant_lowest_length = 1;
|
||||
else
|
||||
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
|
||||
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize;
|
||||
@@ -306,6 +314,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
@@ -319,6 +328,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
|
||||
@@ -31,8 +31,8 @@
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
#include "reduction_functions_blockwise.hpp"
|
||||
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -158,13 +158,27 @@ struct GridwiseReduction_mk_to_m_blockwise
|
||||
{
|
||||
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
|
||||
|
||||
static constexpr auto buffer_1d_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
|
||||
// Dim_K as the fastest one
|
||||
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
|
||||
|
||||
template <typename T>
|
||||
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
@@ -180,14 +194,12 @@ struct GridwiseReduction_mk_to_m_blockwise
|
||||
const IndexDataType* const __restrict__ p_ws_indices_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
|
||||
AccDataType,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
reorder_thread_cluster,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
using Accumulation =
|
||||
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
|
||||
|
||||
@@ -221,31 +233,31 @@ struct GridwiseReduction_mk_to_m_blockwise
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_1d_id = get_block_1d_id();
|
||||
const index_t thread_m_cluster_id =
|
||||
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
|
||||
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
|
||||
const index_t thread_k_cluster_id =
|
||||
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
|
||||
: thread_local_id % KThreadClusterSize;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
@@ -283,21 +295,14 @@ struct GridwiseReduction_mk_to_m_blockwise
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(reorder_thread_cluster)
|
||||
{
|
||||
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
|
||||
accu_value_buf[I];
|
||||
}
|
||||
else
|
||||
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
|
||||
accu_value_buf[I];
|
||||
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
|
||||
accu_value_buf[I];
|
||||
|
||||
accu_value_buf(I) = zeroVal;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
BlockwiseReduce::Reduce(
|
||||
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
|
||||
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
@@ -380,15 +385,13 @@ struct GridwiseReduction_mk_to_m_blockwise
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
using BlockwiseReduceWithIndex =
|
||||
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
reorder_thread_cluster,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
PartitionedBlockwiseReductionWithIndex<AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
@@ -432,31 +435,31 @@ struct GridwiseReduction_mk_to_m_blockwise
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_1d_id = get_block_1d_id();
|
||||
const index_t thread_m_cluster_id =
|
||||
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
|
||||
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
|
||||
const index_t thread_k_cluster_id =
|
||||
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
|
||||
: thread_local_id % KThreadClusterSize;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
index_t indexOffset = 0;
|
||||
|
||||
@@ -503,29 +506,15 @@ struct GridwiseReduction_mk_to_m_blockwise
|
||||
});
|
||||
|
||||
// store thread local value to LDS for parallel reduction
|
||||
if constexpr(reorder_thread_cluster)
|
||||
{
|
||||
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
|
||||
thread_m_cluster_id) = tmpValue;
|
||||
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
|
||||
thread_m_cluster_id) = tmpIndex;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
|
||||
thread_k_cluster_id) = tmpValue;
|
||||
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
|
||||
thread_k_cluster_id) = tmpIndex;
|
||||
}
|
||||
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
|
||||
tmpValue;
|
||||
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
|
||||
tmpIndex;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
|
||||
block_reduce_idx_buf,
|
||||
tmpValue,
|
||||
tmpIndex,
|
||||
thread_m_cluster_id,
|
||||
thread_k_cluster_id);
|
||||
BlockwiseReduceWithIndex::Reduce(
|
||||
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
|
||||
|
||||
AccumulationWithIndex::Calculate(
|
||||
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
|
||||
@@ -648,15 +637,13 @@ struct GridwiseReduction_mk_to_m_blockwise
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
using BlockwiseReduceWithIndex =
|
||||
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
reorder_thread_cluster,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
PartitionedBlockwiseReductionWithIndex<AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
Sequence<MThreadClusterSize, KThreadClusterSize>,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
@@ -707,46 +694,48 @@ struct GridwiseReduction_mk_to_m_blockwise
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_1d_id = get_block_1d_id();
|
||||
const index_t thread_m_cluster_id =
|
||||
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
|
||||
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
|
||||
const index_t thread_k_cluster_id =
|
||||
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
|
||||
: thread_local_id % KThreadClusterSize;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
auto threadwise_src_val_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
IndexDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
auto threadwise_src_idx_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
// index_t indexOffset = 0;
|
||||
|
||||
@@ -787,29 +776,15 @@ struct GridwiseReduction_mk_to_m_blockwise
|
||||
});
|
||||
|
||||
// store thread local value to LDS for parallel reduction
|
||||
if constexpr(reorder_thread_cluster)
|
||||
{
|
||||
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
|
||||
thread_m_cluster_id) = tmpValue;
|
||||
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
|
||||
thread_m_cluster_id) = tmpIndex;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
|
||||
thread_k_cluster_id) = tmpValue;
|
||||
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
|
||||
thread_k_cluster_id) = tmpIndex;
|
||||
}
|
||||
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
|
||||
tmpValue;
|
||||
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
|
||||
tmpIndex;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
|
||||
block_reduce_idx_buf,
|
||||
tmpValue,
|
||||
tmpIndex,
|
||||
thread_m_cluster_id,
|
||||
thread_k_cluster_id);
|
||||
BlockwiseReduceWithIndex::Reduce(
|
||||
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
|
||||
|
||||
AccumulationWithIndex::Calculate(
|
||||
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
|
||||
|
||||
@@ -86,22 +86,34 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
|
||||
{
|
||||
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
|
||||
|
||||
static constexpr auto buffer_1d_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using blockwise_reduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
|
||||
AccDataType,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
reorder_thread_cluster,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
|
||||
// Dim_K as the fastest one
|
||||
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
|
||||
|
||||
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
template <typename T>
|
||||
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
@@ -145,12 +157,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / block_group_size;
|
||||
const index_t block_local_id = block_global_id % block_group_size;
|
||||
const index_t thread_m_cluster_id =
|
||||
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
|
||||
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
|
||||
const index_t thread_k_cluster_id =
|
||||
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
|
||||
: thread_local_id % KThreadClusterSize;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
@@ -158,17 +170,16 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
@@ -212,21 +223,14 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
|
||||
// consistent reduced result for that invariant dimension. due to the using of vector_load,
|
||||
// each block/thread is involved into multiple invarirant dimensions.
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(reorder_thread_cluster)
|
||||
{
|
||||
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
|
||||
accu_value_buf[I];
|
||||
}
|
||||
else
|
||||
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
|
||||
accu_value_buf[I];
|
||||
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
|
||||
accu_value_buf[I];
|
||||
|
||||
accu_value_buf(I) = zeroVal;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_reduce::Reduce(
|
||||
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
|
||||
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
|
||||
@@ -30,8 +30,8 @@
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
#include "reduction_functions_blockwise.hpp"
|
||||
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -103,13 +103,27 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
{
|
||||
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
|
||||
|
||||
static constexpr auto buffer1dDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
|
||||
// Dim_K as the fastest one
|
||||
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
|
||||
|
||||
template <typename T>
|
||||
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
@@ -124,14 +138,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
AccDataType* const __restrict__ p_ws_values_global,
|
||||
IndexDataType* const __restrict__ p_ws_indices_global)
|
||||
{
|
||||
using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer1dDesc),
|
||||
AccDataType,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
reorder_thread_cluster,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using Accumulation =
|
||||
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
|
||||
@@ -168,12 +180,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / block_group_size;
|
||||
const index_t block_local_id = block_global_id % block_group_size;
|
||||
const index_t thread_m_cluster_id =
|
||||
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
|
||||
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
|
||||
const index_t thread_k_cluster_id =
|
||||
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
|
||||
: thread_local_id % KThreadClusterSize;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
@@ -181,17 +193,16 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
@@ -233,21 +244,14 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
// Each block executes multiple parallel reductions on the LDS, and due to the using of
|
||||
// vector_load, each block/thread is involved into multiple invarirant dimensions.
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(reorder_thread_cluster)
|
||||
{
|
||||
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
|
||||
accu_value_buf[I];
|
||||
}
|
||||
else
|
||||
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
|
||||
accu_value_buf[I];
|
||||
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
|
||||
accu_value_buf[I];
|
||||
|
||||
accu_value_buf(I) = zeroVal;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
BlockwiseReduce::Reduce(
|
||||
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
|
||||
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
|
||||
});
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
@@ -290,15 +294,13 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
IndexDataType* const __restrict__ p_ws_indices_global)
|
||||
{
|
||||
using BlockwiseReduceWithIndex =
|
||||
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer1dDesc),
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
reorder_thread_cluster,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
PartitionedBlockwiseReductionWithIndex<AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
@@ -346,12 +348,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / block_group_size;
|
||||
const index_t block_local_id = block_global_id % block_group_size;
|
||||
const index_t thread_m_cluster_id =
|
||||
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
|
||||
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
|
||||
const index_t thread_k_cluster_id =
|
||||
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
|
||||
: thread_local_id % KThreadClusterSize;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
@@ -359,17 +361,16 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
@@ -418,29 +419,15 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
});
|
||||
|
||||
// store thread local value to LDS for parallel reduction
|
||||
if constexpr(reorder_thread_cluster)
|
||||
{
|
||||
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
|
||||
thread_m_cluster_id) = tmpValue;
|
||||
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
|
||||
thread_m_cluster_id) = tmpIndex;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
|
||||
thread_k_cluster_id) = tmpValue;
|
||||
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
|
||||
thread_k_cluster_id) = tmpIndex;
|
||||
}
|
||||
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
|
||||
tmpValue;
|
||||
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
|
||||
tmpIndex;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
|
||||
block_reduce_idx_buf,
|
||||
tmpValue,
|
||||
tmpIndex,
|
||||
thread_m_cluster_id,
|
||||
thread_k_cluster_id);
|
||||
BlockwiseReduceWithIndex::Reduce(
|
||||
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
|
||||
|
||||
AccumulationWithIndex::Calculate(
|
||||
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
|
||||
|
||||
@@ -101,6 +101,9 @@ template <typename InDataType,
|
||||
index_t OutDstVectorSize>
|
||||
struct GridwiseReduction_mk_to_m_threadwise
|
||||
{
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
template <typename T>
|
||||
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
|
||||
@@ -147,17 +150,17 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
|
||||
|
||||
@@ -299,17 +302,17 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
int NumReduceDim,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt>
|
||||
@@ -91,7 +91,7 @@ void add_device_reduce_instance_blockwise(
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
@@ -112,34 +112,36 @@ void add_device_reduce_instance_blockwise(
|
||||
});
|
||||
};
|
||||
|
||||
#define ADD_BLOCKWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
template void add_device_reduce_instance_blockwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
#define ADD_BLOCKWISE_INST_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
template void add_device_reduce_instance_blockwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
|
||||
|
||||
#define ADD_BLOCKWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
#define ADD_BLOCKWISE_INST_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
NumReduceDim)
|
||||
|
||||
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_blockwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
@@ -149,15 +151,16 @@ void add_device_reduce_instance_blockwise(
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_BLOCKWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
NumReduceDim)
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
|
||||
@@ -11,25 +11,25 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,16 +11,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,34 +11,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,16 +11,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,34 +11,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -45,7 +45,7 @@ template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
int NumReduceDim,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt>
|
||||
@@ -86,7 +86,7 @@ void add_device_reduce_instance_blockwise_second_call(
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
@@ -106,21 +106,21 @@ void add_device_reduce_instance_blockwise_second_call(
|
||||
});
|
||||
};
|
||||
|
||||
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
template void add_device_reduce_instance_blockwise_second_call<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceBlockWiseSecondCallPtrType<compT, ReduceOpId>> & \
|
||||
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
template void add_device_reduce_instance_blockwise_second_call<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceBlockWiseSecondCallPtrType<compT, ReduceOpId>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
@@ -128,27 +128,27 @@ void add_device_reduce_instance_blockwise_second_call(
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
NumReduceDim)
|
||||
|
||||
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
extern template void add_device_reduce_instance_blockwise_second_call<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector< \
|
||||
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
|
||||
InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_blockwise_second_call<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector< \
|
||||
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
|
||||
InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
@@ -156,7 +156,7 @@ void add_device_reduce_instance_blockwise_second_call(
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
NumReduceDim)
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
|
||||
@@ -11,25 +11,25 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,16 +11,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,34 +11,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,16 +11,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,34 +11,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -59,7 +59,7 @@ template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
int NumReduceDim,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt>
|
||||
@@ -110,7 +110,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
@@ -132,21 +132,21 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
}
|
||||
};
|
||||
|
||||
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
template void add_device_reduce_instance_multiblock_atomic_add<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
|
||||
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
template void add_device_reduce_instance_multiblock_atomic_add<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
@@ -154,15 +154,15 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
NumReduceDim)
|
||||
|
||||
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
@@ -173,7 +173,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
@@ -181,7 +181,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
NumReduceDim)
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
|
||||
@@ -11,13 +11,13 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,13 +11,13 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,13 +11,13 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -55,7 +55,7 @@ template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
int NumReduceDim,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt>
|
||||
@@ -93,7 +93,7 @@ void add_device_reduce_instance_multiblock_partial_reduce(
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
@@ -113,21 +113,21 @@ void add_device_reduce_instance_multiblock_partial_reduce(
|
||||
});
|
||||
};
|
||||
|
||||
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
template void add_device_reduce_instance_multiblock_partial_reduce<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceMultiBlockPartialReducePtrType<compT, ReduceOpId>> & \
|
||||
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
template void add_device_reduce_instance_multiblock_partial_reduce<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceMultiBlockPartialReducePtrType<compT, ReduceOpId>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
@@ -135,28 +135,27 @@ void add_device_reduce_instance_multiblock_partial_reduce(
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
NumReduceDim)
|
||||
|
||||
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
extern template void \
|
||||
add_device_reduce_instance_multiblock_partial_reduce<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector< \
|
||||
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
|
||||
InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_multiblock_partial_reduce<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector< \
|
||||
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
|
||||
InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
@@ -164,7 +163,7 @@ void add_device_reduce_instance_multiblock_partial_reduce(
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
NumReduceDim)
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
|
||||
@@ -11,25 +11,25 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,16 +11,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,29 +11,29 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,10 +11,10 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,37 +11,37 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
|
||||
// Will be moved to use MultiBlockAtomicAdd
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -57,7 +57,7 @@ template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
int NumReduceDim,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt>
|
||||
@@ -89,7 +89,7 @@ void add_device_reduce_instance_threadwise(
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
@@ -108,34 +108,36 @@ void add_device_reduce_instance_threadwise(
|
||||
});
|
||||
};
|
||||
|
||||
#define ADD_THREADWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
template void add_device_reduce_instance_threadwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
#define ADD_THREADWISE_INST_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
template void add_device_reduce_instance_threadwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
|
||||
|
||||
#define ADD_THREADWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_THREADWISE_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
#define ADD_THREADWISE_INST_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
ADD_THREADWISE_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
NumReduceDim)
|
||||
|
||||
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_threadwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
@@ -145,15 +147,16 @@ void add_device_reduce_instance_threadwise(
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_THREADWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_THREADWISE_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
#define ADD_THREADWISE_INST_REF_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
ADD_THREADWISE_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
NumReduceDim)
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
|
||||
@@ -11,25 +11,25 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,16 +11,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,34 +11,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,16 +11,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -11,34 +11,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,25 +6,25 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,16 +6,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,34 +6,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,16 +6,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,34 +6,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,25 +6,25 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,16 +6,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,34 +6,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,16 +6,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,34 +6,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,13 +6,13 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,13 +6,13 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,13 +6,13 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,25 +6,25 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,16 +6,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,29 +6,29 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,10 +6,10 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,37 +6,37 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
|
||||
// Will be moved to use MultiBlockAtomicAdd
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,25 +6,25 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,16 +6,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,34 +6,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,16 +6,16 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -6,34 +6,34 @@ namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0);
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
|
||||
@@ -9,54 +9,52 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
template <int Rank, typename ReduceDims, int ReduceOpId, int NanOpt, int IndicesOpt>
|
||||
template <int Rank, int NumReduceDim, int ReduceOpId, int NanOpt, int IndicesOpt>
|
||||
struct ReduceDescription
|
||||
{
|
||||
static constexpr int Rank_ = Rank;
|
||||
static constexpr int ReduceOpId_ = ReduceOpId;
|
||||
static constexpr int NanOpt_ = NanOpt;
|
||||
static constexpr int IndicesOpt_ = IndicesOpt;
|
||||
|
||||
using ReduceDims_ = ReduceDims;
|
||||
static constexpr int Rank_ = Rank;
|
||||
static constexpr int NumReduceDim_ = NumReduceDim;
|
||||
static constexpr int ReduceOpId_ = ReduceOpId;
|
||||
static constexpr int NanOpt_ = NanOpt;
|
||||
static constexpr int IndicesOpt_ = IndicesOpt;
|
||||
};
|
||||
|
||||
using reduce_description_instances =
|
||||
std::tuple<ReduceDescription<4, Sequence<0, 1, 2>, 0, 0, 0>, // for ADD
|
||||
ReduceDescription<4, Sequence<0>, 0, 0, 0>,
|
||||
ReduceDescription<2, Sequence<1>, 0, 0, 0>,
|
||||
using reduce_description_instances = std::tuple<ReduceDescription<4, 3, 0, 0, 0>, // for ADD
|
||||
ReduceDescription<4, 1, 0, 0, 0>,
|
||||
ReduceDescription<2, 1, 0, 0, 0>,
|
||||
|
||||
ReduceDescription<4, Sequence<0, 1, 2>, 5, 0, 0>, // for AVG
|
||||
ReduceDescription<4, Sequence<0>, 5, 0, 0>,
|
||||
ReduceDescription<2, Sequence<1>, 5, 0, 0>,
|
||||
ReduceDescription<4, 3, 5, 0, 0>, // for AVG
|
||||
ReduceDescription<4, 1, 5, 0, 0>,
|
||||
ReduceDescription<2, 1, 5, 0, 0>,
|
||||
|
||||
ReduceDescription<4, Sequence<0, 1, 2>, 7, 0, 0>, // for NORM2
|
||||
ReduceDescription<4, Sequence<0>, 7, 0, 0>,
|
||||
ReduceDescription<2, Sequence<1>, 7, 0, 0>,
|
||||
ReduceDescription<4, 3, 7, 0, 0>, // for NORM2
|
||||
ReduceDescription<4, 1, 7, 0, 0>,
|
||||
ReduceDescription<2, 1, 7, 0, 0>,
|
||||
|
||||
ReduceDescription<4, Sequence<0, 1, 2>, 2, 0, 0>, // for MIN
|
||||
ReduceDescription<4, Sequence<0>, 2, 0, 0>,
|
||||
ReduceDescription<2, Sequence<1>, 2, 0, 0>,
|
||||
ReduceDescription<4, Sequence<0, 1, 2>, 3, 0, 0>, // for MAX
|
||||
ReduceDescription<4, Sequence<0>, 3, 0, 0>,
|
||||
ReduceDescription<2, Sequence<1>, 3, 0, 0>,
|
||||
ReduceDescription<4, Sequence<0, 1, 2>, 4, 0, 0>, // for AMAX
|
||||
ReduceDescription<4, Sequence<0>, 4, 0, 0>,
|
||||
ReduceDescription<2, Sequence<1>, 4, 0, 0>,
|
||||
ReduceDescription<4, 3, 2, 0, 0>, // for MIN
|
||||
ReduceDescription<4, 1, 2, 0, 0>,
|
||||
ReduceDescription<2, 1, 2, 0, 0>,
|
||||
ReduceDescription<4, 3, 3, 0, 0>, // for MAX
|
||||
ReduceDescription<4, 1, 3, 0, 0>,
|
||||
ReduceDescription<2, 1, 3, 0, 0>,
|
||||
ReduceDescription<4, 3, 4, 0, 0>, // for AMAX
|
||||
ReduceDescription<4, 1, 4, 0, 0>,
|
||||
ReduceDescription<2, 1, 4, 0, 0>,
|
||||
|
||||
ReduceDescription<4, Sequence<0, 1, 2>, 2, 0, 1>, // for MIN
|
||||
ReduceDescription<4, Sequence<0>, 2, 0, 1>,
|
||||
ReduceDescription<2, Sequence<1>, 2, 0, 1>,
|
||||
ReduceDescription<4, Sequence<0, 1, 2>, 3, 0, 1>, // for MAX
|
||||
ReduceDescription<4, Sequence<0>, 3, 0, 1>,
|
||||
ReduceDescription<2, Sequence<1>, 3, 0, 1>,
|
||||
ReduceDescription<4, Sequence<0, 1, 2>, 4, 0, 1>, // for AMAX
|
||||
ReduceDescription<4, Sequence<0>, 4, 0, 1>,
|
||||
ReduceDescription<2, Sequence<1>, 4, 0, 1>>;
|
||||
ReduceDescription<4, 3, 2, 0, 1>, // for MIN
|
||||
ReduceDescription<4, 1, 2, 0, 1>,
|
||||
ReduceDescription<2, 1, 2, 0, 1>,
|
||||
ReduceDescription<4, 3, 3, 0, 1>, // for MAX
|
||||
ReduceDescription<4, 1, 3, 0, 1>,
|
||||
ReduceDescription<2, 1, 3, 0, 1>,
|
||||
ReduceDescription<4, 3, 4, 0, 1>, // for AMAX
|
||||
ReduceDescription<4, 1, 4, 0, 1>,
|
||||
ReduceDescription<2, 1, 4, 0, 1>>;
|
||||
|
||||
template <typename DescriptionType>
|
||||
bool description_match(const DescriptionType& description,
|
||||
int Rank,
|
||||
const std::vector<int>& ReduceDims,
|
||||
const std::vector<int>& reduceDims,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt)
|
||||
@@ -66,16 +64,11 @@ bool description_match(const DescriptionType& description,
|
||||
description.IndicesOpt_ != static_cast<int>(IndicesOpt))
|
||||
return (false);
|
||||
|
||||
if(DescriptionType::ReduceDims_::Size() != ReduceDims.size())
|
||||
if(DescriptionType::NumReduceDim_ != reduceDims.size())
|
||||
return (false);
|
||||
|
||||
bool result = true;
|
||||
|
||||
static_for<0, DescriptionType::ReduceDims_::Size(), 1>{}([&](auto i) {
|
||||
if(DescriptionType::ReduceDims_::At(i) != ReduceDims[i])
|
||||
result = false;
|
||||
});
|
||||
|
||||
return (result);
|
||||
};
|
||||
|
||||
@@ -87,33 +80,29 @@ bool description_match(const DescriptionType& description,
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <int Rank, typename ReduceDims>
|
||||
static std::vector<int> get_reduce_dims()
|
||||
template <index_t Rank, index_t NumReduceDim>
|
||||
static inline std::vector<int> get_invariant_dims(const std::vector<int>& reduceDims)
|
||||
{
|
||||
std::vector<int> resDims;
|
||||
assert(NumReduceDim == reduceDims.size());
|
||||
|
||||
static_for<0, ReduceDims::Size(), 1>{}([&](auto i) { resDims.push_back(ReduceDims::At(i)); });
|
||||
int reduceFlag = 0;
|
||||
|
||||
return (resDims);
|
||||
};
|
||||
|
||||
template <int Rank, typename ReduceDims>
|
||||
static std::vector<int> get_invariant_dims()
|
||||
{
|
||||
std::vector<int> resDims;
|
||||
unsigned int incFlag = 0;
|
||||
|
||||
static_for<0, ReduceDims::Size(), 1>{}(
|
||||
[&](auto i) { incFlag = incFlag | (0x1 << ReduceDims::At(i)); });
|
||||
|
||||
for(int dim = 0; dim < Rank; dim++)
|
||||
// flag the bits for the reduceDims
|
||||
for(int i = 0; i < NumReduceDim; i++)
|
||||
{
|
||||
if(incFlag & (0x1 << dim))
|
||||
continue;
|
||||
resDims.push_back(dim);
|
||||
reduceFlag |= 1 << reduceDims[i];
|
||||
};
|
||||
|
||||
return (resDims);
|
||||
std::vector<int> invariantDims;
|
||||
|
||||
// collect invariant dimensions
|
||||
for(int i = 0; i < Rank; i++)
|
||||
if((reduceFlag & (1 << i)) == 0)
|
||||
{
|
||||
invariantDims.push_back(i);
|
||||
};
|
||||
|
||||
return invariantDims;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -149,7 +138,7 @@ template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims_,
|
||||
int NumReduceDim,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt>
|
||||
@@ -159,6 +148,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
bool do_dumpout,
|
||||
int nrepeat,
|
||||
const std::vector<size_t>& inLengths,
|
||||
const std::vector<int>& reduceDims,
|
||||
float alpha,
|
||||
float beta)
|
||||
{
|
||||
@@ -203,15 +193,14 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
{
|
||||
Tensor<InDataType> in(inLengths);
|
||||
|
||||
const std::vector<int> OuterDims = get_invariant_dims<Rank, ReduceDims_>();
|
||||
const std::vector<int> ReduceDims = get_reduce_dims<Rank, ReduceDims_>();
|
||||
|
||||
std::vector<size_t> outLengths;
|
||||
|
||||
if(OuterDims.empty())
|
||||
const auto invariantDims = get_invariant_dims<Rank, NumReduceDim>(reduceDims);
|
||||
|
||||
if(reduceDims.size() == Rank)
|
||||
outLengths.push_back(1);
|
||||
else
|
||||
for(auto dim : OuterDims)
|
||||
for(auto dim : invariantDims)
|
||||
outLengths.push_back(inLengths[dim]);
|
||||
|
||||
Tensor<OutDataType> out_ref(outLengths);
|
||||
@@ -302,7 +291,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims_,
|
||||
NumReduceDim,
|
||||
ReduceOpId,
|
||||
NanOpt,
|
||||
IndicesOpt>(reduce0_ptrs);
|
||||
@@ -311,7 +300,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims_,
|
||||
NumReduceDim,
|
||||
ReduceOpId,
|
||||
NanOpt,
|
||||
IndicesOpt>(reduce0_ptrs);
|
||||
@@ -321,7 +310,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims_,
|
||||
NumReduceDim,
|
||||
ReduceOpId,
|
||||
NanOpt,
|
||||
IndicesOpt>(reduce0_ptrs);
|
||||
@@ -330,7 +319,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims_,
|
||||
NumReduceDim,
|
||||
ReduceOpId,
|
||||
NanOpt,
|
||||
IndicesOpt>(reduce1_ptrs);
|
||||
@@ -341,7 +330,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims_,
|
||||
NumReduceDim,
|
||||
ReduceOpId,
|
||||
NanOpt,
|
||||
IndicesOpt>(reduce2_ptrs);
|
||||
@@ -358,7 +347,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
using hCompType = typename type_mapping<AccDataType>::outDataType;
|
||||
|
||||
ReductionHost<hInType, hCompType, hOutType, ReduceOpId, PropagateNan, NeedIndices>
|
||||
hostReduce(in.mDesc, out_ref.mDesc, OuterDims, ReduceDims);
|
||||
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(alpha,
|
||||
reinterpret_cast<const hInType*>(in.mData.data()),
|
||||
@@ -383,6 +372,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
@@ -464,6 +454,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
@@ -496,6 +487,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
inStrides2,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
ws_dev.GetDeviceBuffer(),
|
||||
@@ -584,7 +576,7 @@ void profile_reduce_impl(bool do_verification,
|
||||
bool do_dumpout,
|
||||
int nrepeat,
|
||||
const std::vector<size_t>& inLengths,
|
||||
const std::vector<int>& ReduceDims,
|
||||
const std::vector<int>& reduceDims,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt,
|
||||
@@ -605,18 +597,26 @@ void profile_reduce_impl(bool do_verification,
|
||||
using descType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
|
||||
|
||||
if(!description_match(
|
||||
descType{}, inLengths.size(), ReduceDims, ReduceOpId, NanOpt, IndicesOpt))
|
||||
descType{}, inLengths.size(), reduceDims, ReduceOpId, NanOpt, IndicesOpt))
|
||||
return;
|
||||
|
||||
profile_reduce_impl_impl<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
descType::Rank_,
|
||||
typename descType::ReduceDims_,
|
||||
descType::NumReduceDim_,
|
||||
static_cast<ReduceTensorOp_t>(descType::ReduceOpId_),
|
||||
static_cast<NanPropagation_t>(descType::NanOpt_),
|
||||
static_cast<ReduceTensorIndices_t>(descType::IndicesOpt_)>(
|
||||
do_verification, init_method, do_log, do_dumpout, nrepeat, inLengths, alpha, beta);
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
do_dumpout,
|
||||
nrepeat,
|
||||
inLengths,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta);
|
||||
|
||||
matched = true;
|
||||
});
|
||||
|
||||
@@ -25,7 +25,7 @@ using ck::ReduceTensorIndices_t;
|
||||
using ck::ReduceTensorOp_t;
|
||||
|
||||
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
|
||||
{"toReduceDims", required_argument, nullptr, 'R'},
|
||||
{"reduceDims", required_argument, nullptr, 'R'},
|
||||
{"reduceOp", required_argument, nullptr, 'O'},
|
||||
{"compType", required_argument, nullptr, 'C'},
|
||||
{"outType", required_argument, nullptr, 'W'},
|
||||
@@ -93,9 +93,9 @@ typedef enum
|
||||
appDouble = 6,
|
||||
} appDataType_t;
|
||||
|
||||
static void check_reduce_dims(const int rank, const std::vector<int>& toReduceDims)
|
||||
static void check_reduce_dims(const int rank, const std::vector<int>& reduceDims)
|
||||
{
|
||||
for(auto dim : toReduceDims)
|
||||
for(auto dim : reduceDims)
|
||||
{
|
||||
if(dim < 0 || dim >= rank)
|
||||
throw std::runtime_error("Invalid dimension index specified for Reducing");
|
||||
@@ -103,7 +103,7 @@ static void check_reduce_dims(const int rank, const std::vector<int>& toReduceDi
|
||||
|
||||
unsigned int flag = 0;
|
||||
|
||||
for(auto dim : toReduceDims)
|
||||
for(auto dim : reduceDims)
|
||||
{
|
||||
if(flag & (0x1 << dim))
|
||||
throw std::runtime_error("All toReduce dimensions should be different!");
|
||||
@@ -122,7 +122,7 @@ class AppArgs
|
||||
|
||||
std::vector<size_t> inLengths;
|
||||
std::vector<size_t> outLengths;
|
||||
std::vector<int> toReduceDims;
|
||||
std::vector<int> reduceDims;
|
||||
|
||||
std::vector<float> scales;
|
||||
|
||||
@@ -152,7 +152,7 @@ class AppArgs
|
||||
std::cout << "Usage of " << cmd << std::endl;
|
||||
std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
|
||||
<< std::endl;
|
||||
std::cout << "--toReduceDims or -R, comma separated list of to-reduce dimensions"
|
||||
std::cout << "--reduceDims or -R, comma separated list of to-reduce dimensions"
|
||||
<< std::endl;
|
||||
std::cout << "--reduceOp or -O, enum value indicating the reduction operations"
|
||||
<< std::endl;
|
||||
@@ -201,7 +201,7 @@ class AppArgs
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
toReduceDims = getTypeValuesFromString<int>(optarg);
|
||||
reduceDims = getTypeValuesFromString<int>(optarg);
|
||||
break;
|
||||
case 'O':
|
||||
if(!optarg)
|
||||
@@ -321,7 +321,7 @@ int profile_reduce(int argc, char* argv[])
|
||||
|
||||
int rank = args.inLengths.size();
|
||||
|
||||
check_reduce_dims(rank, args.toReduceDims);
|
||||
check_reduce_dims(rank, args.reduceDims);
|
||||
|
||||
if(args.reduceOp == ReduceTensorOp_t::MUL || args.reduceOp == ReduceTensorOp_t::NORM1)
|
||||
throw std::runtime_error("MUL and NORM1 are not supported by composable kernel!");
|
||||
@@ -345,7 +345,7 @@ int profile_reduce(int argc, char* argv[])
|
||||
args.do_dumpout,
|
||||
args.nrepeat,
|
||||
args.inLengths,
|
||||
args.toReduceDims,
|
||||
args.reduceDims,
|
||||
args.reduceOp,
|
||||
args.nanOpt,
|
||||
args.indicesOpt,
|
||||
@@ -360,7 +360,7 @@ int profile_reduce(int argc, char* argv[])
|
||||
args.do_dumpout,
|
||||
args.nrepeat,
|
||||
args.inLengths,
|
||||
args.toReduceDims,
|
||||
args.reduceDims,
|
||||
args.reduceOp,
|
||||
args.nanOpt,
|
||||
args.indicesOpt,
|
||||
@@ -378,7 +378,7 @@ int profile_reduce(int argc, char* argv[])
|
||||
args.do_dumpout,
|
||||
args.nrepeat,
|
||||
args.inLengths,
|
||||
args.toReduceDims,
|
||||
args.reduceDims,
|
||||
args.reduceOp,
|
||||
args.nanOpt,
|
||||
args.indicesOpt,
|
||||
@@ -395,7 +395,7 @@ int profile_reduce(int argc, char* argv[])
|
||||
args.do_dumpout,
|
||||
args.nrepeat,
|
||||
args.inLengths,
|
||||
args.toReduceDims,
|
||||
args.reduceDims,
|
||||
args.reduceOp,
|
||||
args.nanOpt,
|
||||
args.indicesOpt,
|
||||
@@ -410,7 +410,7 @@ int profile_reduce(int argc, char* argv[])
|
||||
args.do_dumpout,
|
||||
args.nrepeat,
|
||||
args.inLengths,
|
||||
args.toReduceDims,
|
||||
args.reduceDims,
|
||||
args.reduceOp,
|
||||
args.nanOpt,
|
||||
args.indicesOpt,
|
||||
|
||||
@@ -1,66 +1,74 @@
|
||||
#!/bin/bash
|
||||
|
||||
PRECISION= ##--half
|
||||
PRECISION=
|
||||
##PRECISION=--half
|
||||
##PRECISION=--double
|
||||
|
||||
if test -n $PRECISION && test "$PRECISION" = "--half"; then
|
||||
CTYPE="-C 1"
|
||||
ACCTYPE="-C 1"
|
||||
else
|
||||
CTYPE=""
|
||||
ACCTYPE=""
|
||||
fi
|
||||
|
||||
WTYPE=
|
||||
driver="./bin/ckProfiler"
|
||||
|
||||
if [ $# -ge 1 ] ; then
|
||||
NREPEAT=$1
|
||||
else
|
||||
NREPEAT=1
|
||||
fi
|
||||
VERIFY="-v $1"
|
||||
INIT=$2
|
||||
NREPEAT=$3
|
||||
|
||||
Operation=7
|
||||
|
||||
#### 0 - ADD, 5 - AVG, 7 - NORM2
|
||||
Operations="0 5 7"
|
||||
|
||||
## for generic validation
|
||||
for op in $Operation; do
|
||||
for op in $Operations; do
|
||||
set -x
|
||||
./bin/ckProfiler reduce $PRECISION -D 64,4,280,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 4,64,280,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 280,4,64,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 4,64,280,82 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 64,280,82,4 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 700,8192 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 700,1024 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 700,4 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
####### datatype layout reduce dims op acctype verify init repeats
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 1,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 0,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 0,1,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,22960 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,22960 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 4,1469440 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 4,1469440 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
set +x
|
||||
done
|
||||
|
||||
Operation=5
|
||||
#### 0 - ADD, 5 - AVG, 7 - NORM2
|
||||
Operations=5
|
||||
|
||||
## for performance evaluation (resnet50 NHWC => C)
|
||||
for op in $Operation; do
|
||||
for op in $Operations; do
|
||||
set -x
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
|
||||
####### datatype layout reduce dims op acctype verify init repeats
|
||||
$driver reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
|
||||
set +x
|
||||
done
|
||||
|
||||
|
||||
@@ -1,61 +1,69 @@
|
||||
#!/bin/bash
|
||||
|
||||
PRECISION= ##--half
|
||||
PRECISION=
|
||||
##PRECISION=--half
|
||||
##PRECISION=--double
|
||||
|
||||
if [ $# -ge 1 ] ; then
|
||||
NREPEAT=$1
|
||||
else
|
||||
NREPEAT=1
|
||||
fi
|
||||
driver="./bin/ckProfiler"
|
||||
|
||||
Operation=4
|
||||
VERIFY="-v $1"
|
||||
INIT=$2
|
||||
NREPEAT=$3
|
||||
|
||||
LENGTHS=64,4,280,82
|
||||
#### 2 - MIN, 3 - MAX, 4 - AMAX
|
||||
Operations="2 4"
|
||||
|
||||
## for generic validation
|
||||
for op in $Operation; do
|
||||
for op in $Operations; do
|
||||
for use_idx in 0 1; do
|
||||
set -x
|
||||
./bin/ckProfiler reduce $PRECISION -D 64,4,280,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 4,64,280,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 280,4,64,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 4,64,280,82 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 64,280,82,4 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 700,8192 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 700,1024 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 700,4 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
|
||||
####### datatype layout reduce dims op use index verify init repeats
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 1,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 0,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 64,4,280,82 -R 0,1,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,22960 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,22960 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 4,1469440 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 4,1469440 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
set +x
|
||||
done
|
||||
done
|
||||
|
||||
Operations=2
|
||||
|
||||
## for performance evaluation (resnet50 NHWC => C)
|
||||
for op in $Operation; do
|
||||
for op in $Operations; do
|
||||
for use_idx in 0 1; do
|
||||
set -x
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
./bin/ckProfiler reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
|
||||
####### datatype layout reduce dims op use index verify init repeats
|
||||
$driver reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
$driver reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
|
||||
set +x
|
||||
done
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user