Pr82 followup (#115)

* Use thread cluster descriptor and explicit M_K 2d descriptor to simply Blockwise Reduction

* Change by replacing ReduceDims by NumReduceDims as Device Reduce interface template parameter

* Rename the folder name for the pool2d and reduce examples

* Update to reduction test scripts

* Add Readme for pool2d_fwd and reduce_blockwise examples

* Tiny fix in reduce profiler and tiny update in reduce testing scripts

* Tiny fix in testing script profile_reduce_no_index.sh

* Tiny change in script/profile_reduce_with_index.sh

* Renaming and refining in Reduction profiler/device layer/examples

* Renaming and refining in Reduction profiler/device layer/examples

* Renaming all NumReduceDims to NumReduceDim
This commit is contained in:
Qianfeng
2022-03-11 00:14:43 +08:00
committed by GitHub
parent 5d37d7bff4
commit 827301d95a
70 changed files with 1704 additions and 1576 deletions

View File

@@ -0,0 +1,60 @@
# Instructions for ```reduce_blockwise``` Example
## Docker script
```bash
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash
```
## Build ```reduce_blockwise```
```bash
mkdir build && cd build
```
```bash
# Need to specify target ID, example below is gfx908
cmake \
-D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
..
```
```bash
make -j reduce_blockwise
```
## Run ```reduce_blockwise```
```bash
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
#arg1: initialization (0=no init, 1=integer value, 2=decimal value)
#arg2: run kernel # of times (>1)
./bin/reduce_blockwise -D 16,64,32,960 -v 1 1 10
```
Result
```
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 3 times...
Perf: 0.23536 ms, 267.32 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
error: 0
max_diff: 0, 529, 529
root@dc-smc-18:/data/composable_kernel/Build3# bin/reduce_blockwise -D 16,64,32,960 -v 1 1 10
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 10 times...
Perf: 0.23392 ms, 268.966 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
error: 0
max_diff: 0, 528, 528
```

View File

@@ -14,6 +14,7 @@
#include "device_reduce_blockwise.hpp"
#include "host_reduce_util.hpp"
#include "host_generic_reduction.hpp"
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
@@ -28,8 +29,8 @@ using kInDataType = ck::half_t;
using kOutDataType = ck::half_t;
using kAccDataType = float;
constexpr int Rank = 4;
using ReduceDims_ = ck::Sequence<0, 1, 2>;
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
constexpr ReduceTensorOp_t ReduceOpId = ReduceTensorOp_t::NORM2;
constexpr NanPropagation_t NanOpt = NanPropagation_t::PROPAGATE_NAN;
@@ -46,7 +47,7 @@ using DeviceReduceInstance = DeviceReduceBlockWise<kInDataType,
kAccDataType,
kOutDataType,
Rank,
ReduceDims_,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
@@ -192,39 +193,13 @@ class SimpleAppArgs
};
};
template <int Rank, typename ReduceDims>
static std::vector<int> get_reduce_dims()
{
std::vector<int> resDims;
static_for<0, ReduceDims::Size(), 1>{}([&](auto i) { resDims.push_back(ReduceDims::At(i)); });
return (resDims);
};
template <int Rank, typename ReduceDims>
static std::vector<int> get_invariant_dims()
{
std::vector<int> resDims;
unsigned int incFlag = 0;
static_for<0, ReduceDims::Size(), 1>{}(
[&](auto i) { incFlag = incFlag | (0x1 << ReduceDims::At(i)); });
for(int dim = 0; dim < Rank; dim++)
{
if(incFlag & (0x1 << dim))
continue;
resDims.push_back(dim);
};
return (resDims);
};
int main(int argc, char* argv[])
{
using namespace ck::host_reduce;
const std::vector<int> reduceDims{0, 1, 2};
const std::vector<int> invariantDims{3};
SimpleAppArgs args;
if(args.processArgs(argc, argv) < 0)
@@ -260,15 +235,12 @@ int main(int argc, char* argv[])
Tensor<InDataType> in(args.inLengths);
const std::vector<int> InvariantDims = get_invariant_dims<Rank, ReduceDims_>();
const std::vector<int> ReduceDims = get_reduce_dims<Rank, ReduceDims_>();
std::vector<size_t> outLengths;
if(InvariantDims.empty())
if(invariantDims.empty())
outLengths.push_back(1);
else
for(auto dim : InvariantDims)
for(auto dim : invariantDims)
outLengths.push_back(args.inLengths[dim]);
Tensor<OutDataType> out_ref(outLengths);
@@ -328,7 +300,7 @@ int main(int argc, char* argv[])
if(args.do_verification)
{
ReductionHost<InDataType, AccDataType, OutDataType, ReduceOpId, PropagateNan, NeedIndices>
hostReduce(in.mDesc, out_ref.mDesc, InvariantDims, ReduceDims);
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run(
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
@@ -350,6 +322,7 @@ int main(int argc, char* argv[])
i_inStrides,
i_outLengths,
i_outStrides,
reduceDims,
alpha,
beta,
in_dev.GetDeviceBuffer(),

View File

@@ -0,0 +1,55 @@
# Instructions for ```pool2d_fwd``` Example
## Docker script
```bash
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash
```
## Build ```pool2d_fwd```
```bash
mkdir build && cd build
```
```bash
# Need to specify target ID, example below is gfx908
cmake \
-D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
..
```
```bash
make -j pool2d_fwd
```
## Run ```pool2d_fwd```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx
./example/pool2d_fwd 1 1 10
```
Result
```
in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192}
launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1}
Warm up
Start running 10 times...
Perf: 0.415453 ms, 1.37996 TFlops, 749.726 GB/s
error: 0
max_diff: 0, 1, 1
```

View File

@@ -32,57 +32,53 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "cluster_descriptor.hpp"
namespace ck {
template <typename Buffer1dDescType,
typename AccDataType,
template <typename AccDataType,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
bool ReorderThreadClusters,
typename ThreadClusterLengths_M_K,
typename ThreadClusterArrangeOrder,
typename OpReduce,
bool PropagateNan>
struct PartitionedBlockwiseReductionOn1dBuffer
struct PartitionedBlockwiseReduction
{
static constexpr auto buffer_1d_desc = Buffer1dDescType{};
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
"The product of cluster lengths should be same as BlockSize!");
static_assert(KThreadClusterSize > 1, "Parallel reduction need work on at least two elements");
static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
"The buffer size should be the same as BlockSize!");
static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename BufferType>
__device__ static void Reduce(BufferType& block_buffer,
AccDataType& accuData,
index_t thread_m_cluster_id,
index_t thread_k_cluster_id)
__device__ static void Reduce(BufferType& block_buffer, AccDataType& accuData)
{
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
if(thread_k_cluster_id < indOffset)
{
// consider the thread clusters order, ensure the contiguous locations are accessed
// by contiguous Thread-ID
index_t offset1 =
ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(make_tuple(
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
index_t offset2 = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize +
(thread_k_cluster_id + indOffset)));
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));
AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
@@ -93,34 +89,34 @@ struct PartitionedBlockwiseReductionOn1dBuffer
__syncthreads();
});
index_t offset = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize));
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
accuData = type_convert<AccDataType>(block_buffer[offset]);
};
};
template <typename Buffer1dDescType,
typename AccDataType,
template <typename AccDataType,
typename IndexDataType,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
bool ReorderThreadClusters,
typename ThreadClusterLengths_M_K,
typename ThreadClusterArrangeOrder,
typename OpReduce,
bool PropagateNan>
struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
struct PartitionedBlockwiseReductionWithIndex
{
static constexpr auto buffer_1d_desc = Buffer1dDescType{};
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
"The product of cluster lengths should be same as BlockSize!");
static_assert(KThreadClusterSize > 1, "Parallel reduction need work on at least two elements");
static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
"The buffer size should be the same as BlockSize!");
static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
@@ -130,32 +126,24 @@ struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
__device__ static void Reduce(BufferType& block_val_buffer,
IdxBufferType& block_idx_buffer,
AccDataType& accuData,
IndexDataType& accuIndex,
index_t thread_m_cluster_id,
index_t thread_k_cluster_id)
IndexDataType& accuIndex)
{
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << I();
if(thread_k_cluster_id % (indOffset * 2) == 0)
{
// consider the thread clusters order, ensure the contiguous locations are accessed
// by contiguous Thread-ID
index_t offset1 =
ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(make_tuple(
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
index_t offset2 = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize +
(thread_k_cluster_id + indOffset)));
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));
AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]);
@@ -170,10 +158,7 @@ struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
__syncthreads();
});
index_t offset = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize));
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
accuData = type_convert<AccDataType>(block_val_buffer[offset]);
accuIndex = block_idx_buffer[offset];

View File

@@ -36,14 +36,15 @@ struct DeviceReduce : public BaseOperator
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation& inElementwiseOp,
const AccElementwiseOperation& accElementwiseOp) = 0;
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

View File

@@ -15,8 +15,8 @@ namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
@@ -40,7 +40,12 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
static constexpr bool BetaIsZero = NeedIndices;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
@@ -74,7 +79,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
}
else
{
const auto toReduceDimLengths =
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -82,7 +87,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
@@ -136,6 +141,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
@@ -144,30 +150,31 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)workspace_dev;
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
std::tie(inLengths_, inStrides_) =
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, ReduceDims>(inLengths_);
if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
@@ -305,6 +312,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
@@ -318,6 +326,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),

View File

@@ -15,8 +15,8 @@ namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
@@ -45,7 +45,11 @@ struct DeviceReduceBlockWiseSecondCall
std::is_same<InDataType, AccDataType>::value,
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
@@ -117,16 +121,16 @@ struct DeviceReduceBlockWiseSecondCall
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
: inLengths_(inLengths),
inStrides_(inStrides),
outLengths_(outLengths),
outStrides_(outStrides),
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_(in_elementwise_op),
acc_elementwise_op_(acc_elementwise_op)
{
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);
@@ -268,6 +272,7 @@ struct DeviceReduceBlockWiseSecondCall
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
@@ -277,6 +282,8 @@ struct DeviceReduceBlockWiseSecondCall
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) override
{
(void)reduceDims;
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,

View File

@@ -2,6 +2,7 @@
#define DEVICE_REDUCE_COMMON_HPP
#include <vector>
#include <cassert>
#include "common_header.hpp"
#include "reduction_enums.hpp"
@@ -40,23 +41,6 @@ constexpr bool belong()
return (inside);
};
template <int Rank, typename ReduceDims, int start = 0>
constexpr auto get_invariant_dims()
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
if constexpr(start >= Rank)
return Sequence<>{};
else
{
if constexpr(!belong<start, ReduceDims>())
return merge_sequences(Sequence<start>{},
get_invariant_dims<Rank, ReduceDims, start + 1>());
else
return get_invariant_dims<Rank, ReduceDims, start + 1>();
};
};
// helper functions using variadic template arguments
template <index_t... Ns>
static auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
@@ -74,6 +58,45 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};
template <index_t Rank, index_t NumReduceDim>
static inline std::pair<std::vector<int>, std::vector<int>>
shuffle_tensor_dimensions(const std::vector<int>& dimLengths,
const std::vector<int>& dimStrides,
const std::vector<int>& reduceDims)
{
std::vector<int> newDimLengths;
std::vector<int> newDimStrides;
assert(Rank == dimLengths.size() && Rank == dimStrides.size() &&
NumReduceDim == reduceDims.size());
int reduceFlag = 0;
// flag the bits for the reduceDims
for(int i = 0; i < NumReduceDim; i++)
{
reduceFlag |= 1 << reduceDims[i];
};
// collect invariant dimensions
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) == 0)
{
newDimLengths.push_back(dimLengths[i]);
newDimStrides.push_back(dimStrides[i]);
};
// collect reduce dimensions
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) > 0)
{
newDimLengths.push_back(dimLengths[i]);
newDimStrides.push_back(dimStrides[i]);
};
return std::make_pair(newDimLengths, newDimStrides);
};
} // namespace device
} // namespace tensor_operation

View File

@@ -17,8 +17,8 @@ namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
@@ -41,7 +41,12 @@ struct DeviceReduceMultiBlockAtomicAdd
using IndexDataType = int32_t;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
@@ -84,7 +89,7 @@ struct DeviceReduceMultiBlockAtomicAdd
}
else
{
const auto toReduceDimLengths =
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -92,7 +97,7 @@ struct DeviceReduceMultiBlockAtomicAdd
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
@@ -147,6 +152,7 @@ struct DeviceReduceMultiBlockAtomicAdd
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
@@ -155,31 +161,31 @@ struct DeviceReduceMultiBlockAtomicAdd
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)out_indices_dev;
(void)workspace_dev;
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
std::tie(inLengths_, inStrides_) =
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, ReduceDims>(inLengths_);
if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
int iterations = 1;
while(true)
@@ -369,6 +375,7 @@ struct DeviceReduceMultiBlockAtomicAdd
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
@@ -382,6 +389,7 @@ struct DeviceReduceMultiBlockAtomicAdd
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),

View File

@@ -15,8 +15,8 @@ namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
@@ -41,7 +41,12 @@ struct DeviceReduceMultiBlockPartialReduce
using IndexDataType = int32_t;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
@@ -112,7 +117,7 @@ struct DeviceReduceMultiBlockPartialReduce
}
else
{
const auto toReduceDimLengths =
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -120,7 +125,7 @@ struct DeviceReduceMultiBlockPartialReduce
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
@@ -161,10 +166,11 @@ struct DeviceReduceMultiBlockPartialReduce
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides,
Argument(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
@@ -173,31 +179,30 @@ struct DeviceReduceMultiBlockPartialReduce
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev},
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
workspace_dev_{workspace_dev}
workspace_dev_{workspace_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
std::tie(inLengths_, inStrides_) =
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, ReduceDims>(inLengths_);
if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
int iterations = 1;
while(true)
@@ -370,6 +375,7 @@ struct DeviceReduceMultiBlockPartialReduce
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
@@ -383,6 +389,7 @@ struct DeviceReduceMultiBlockPartialReduce
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),

View File

@@ -16,7 +16,7 @@ template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
typename ReduceDims,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
@@ -40,7 +40,12 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static constexpr bool BetaIsZero = NeedIndices;
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
@@ -74,7 +79,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
}
else
{
const auto toReduceDimLengths =
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -82,7 +87,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
@@ -136,6 +141,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
@@ -144,30 +150,32 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)workspace_dev;
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;
in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
std::tie(inLengths_, inStrides_) =
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, ReduceDims>(inLengths_);
if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
@@ -306,6 +314,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
@@ -319,6 +328,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),

View File

@@ -31,8 +31,8 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
namespace ck {
@@ -158,13 +158,27 @@ struct GridwiseReduction_mk_to_m_blockwise
{
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
static constexpr auto buffer_1d_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -180,14 +194,12 @@ struct GridwiseReduction_mk_to_m_blockwise
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
@@ -221,31 +233,31 @@ struct GridwiseReduction_mk_to_m_blockwise
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
@@ -283,21 +295,14 @@ struct GridwiseReduction_mk_to_m_blockwise
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster)
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I];
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
BlockwiseReduce::Reduce(
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -380,15 +385,13 @@ struct GridwiseReduction_mk_to_m_blockwise
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
IndexDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
@@ -432,31 +435,31 @@ struct GridwiseReduction_mk_to_m_blockwise
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
index_t indexOffset = 0;
@@ -503,29 +506,15 @@ struct GridwiseReduction_mk_to_m_blockwise
});
// store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster)
{
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;
__syncthreads();
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
block_reduce_idx_buf,
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
@@ -648,15 +637,13 @@ struct GridwiseReduction_mk_to_m_blockwise
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
IndexDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
@@ -707,46 +694,48 @@ struct GridwiseReduction_mk_to_m_blockwise
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<
IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
// index_t indexOffset = 0;
@@ -787,29 +776,15 @@ struct GridwiseReduction_mk_to_m_blockwise
});
// store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster)
{
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;
__syncthreads();
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
block_reduce_idx_buf,
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);

View File

@@ -86,22 +86,34 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
static constexpr auto buffer_1d_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using blockwise_reduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -145,12 +157,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
@@ -158,17 +170,16 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
@@ -212,21 +223,14 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
// consistent reduced result for that invariant dimension. due to the using of vector_load,
// each block/thread is involved into multiple invarirant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster)
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I];
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
blockwise_reduce::Reduce(
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {

View File

@@ -30,8 +30,8 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
namespace ck {
@@ -103,13 +103,27 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
{
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
static constexpr auto buffer1dDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -124,14 +138,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType* const __restrict__ p_ws_values_global,
IndexDataType* const __restrict__ p_ws_indices_global)
{
using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer1dDesc),
AccDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
@@ -168,12 +180,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
@@ -181,17 +193,16 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
@@ -233,21 +244,14 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
// Each block executes multiple parallel reductions on the LDS, and due to the using of
// vector_load, each block/thread is involved into multiple invarirant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster)
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I];
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
BlockwiseReduce::Reduce(
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});
if(thread_k_cluster_id == 0)
@@ -290,15 +294,13 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
IndexDataType* const __restrict__ p_ws_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer1dDesc),
AccDataType,
IndexDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
@@ -346,12 +348,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
@@ -359,17 +361,16 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
@@ -418,29 +419,15 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
});
// store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster)
{
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;
__syncthreads();
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
block_reduce_idx_buf,
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);

View File

@@ -101,6 +101,9 @@ template <typename InDataType,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise
{
using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
@@ -147,17 +150,17 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
@@ -299,17 +302,17 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

View File

@@ -57,7 +57,7 @@ template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
int NumReduceDim,
ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt>
@@ -91,7 +91,7 @@ void add_device_reduce_instance_blockwise(
AccDataType,
OutDataType,
Rank,
ReduceDims,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
@@ -112,34 +112,36 @@ void add_device_reduce_instance_blockwise(
});
};
#define ADD_BLOCKWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
#define ADD_BLOCKWISE_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
#define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
@@ -149,15 +151,16 @@ void add_device_reduce_instance_blockwise(
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
NumReduceDim)
} // namespace device_reduce_instance
} // namespace device

View File

@@ -11,25 +11,25 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,34 +11,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,34 +11,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -45,7 +45,7 @@ template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
int NumReduceDim,
ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt>
@@ -86,7 +86,7 @@ void add_device_reduce_instance_blockwise_second_call(
AccDataType,
OutDataType,
Rank,
ReduceDims,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
@@ -106,21 +106,21 @@ void add_device_reduce_instance_blockwise_second_call(
});
};
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
template void add_device_reduce_instance_blockwise_second_call<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceBlockWiseSecondCallPtrType<compT, ReduceOpId>> & \
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise_second_call<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceBlockWiseSecondCallPtrType<compT, ReduceOpId>> & \
device_op_instances)
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE(inT, \
compT, \
outT, \
@@ -128,27 +128,27 @@ void add_device_reduce_instance_blockwise_second_call(
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
NumReduceDim)
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
extern template void add_device_reduce_instance_blockwise_second_call<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector< \
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
AccElementwiseOperation>> & \
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise_second_call<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector< \
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
@@ -156,7 +156,7 @@ void add_device_reduce_instance_blockwise_second_call(
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
NumReduceDim)
} // namespace device_reduce_instance
} // namespace device

View File

@@ -11,25 +11,25 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,34 +11,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,34 +11,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -59,7 +59,7 @@ template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
int NumReduceDim,
ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt>
@@ -110,7 +110,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
AccDataType,
OutDataType,
Rank,
ReduceDims,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
@@ -132,21 +132,21 @@ void add_device_reduce_instance_multiblock_atomic_add(
}
};
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(inT, \
compT, \
outT, \
@@ -154,15 +154,15 @@ void add_device_reduce_instance_multiblock_atomic_add(
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
NumReduceDim)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
@@ -173,7 +173,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
@@ -181,7 +181,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
NumReduceDim)
} // namespace device_reduce_instance
} // namespace device

View File

@@ -11,13 +11,13 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,13 +11,13 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,13 +11,13 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -55,7 +55,7 @@ template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
int NumReduceDim,
ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt>
@@ -93,7 +93,7 @@ void add_device_reduce_instance_multiblock_partial_reduce(
AccDataType,
OutDataType,
Rank,
ReduceDims,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
@@ -113,21 +113,21 @@ void add_device_reduce_instance_multiblock_partial_reduce(
});
};
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
template void add_device_reduce_instance_multiblock_partial_reduce<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceMultiBlockPartialReducePtrType<compT, ReduceOpId>> & \
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
template void add_device_reduce_instance_multiblock_partial_reduce<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceMultiBlockPartialReducePtrType<compT, ReduceOpId>> & \
device_op_instances)
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE(inT, \
compT, \
outT, \
@@ -135,28 +135,27 @@ void add_device_reduce_instance_multiblock_partial_reduce(
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
NumReduceDim)
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
extern template void \
add_device_reduce_instance_multiblock_partial_reduce<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector< \
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_multiblock_partial_reduce<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector< \
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
@@ -164,7 +163,7 @@ void add_device_reduce_instance_multiblock_partial_reduce(
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
NumReduceDim)
} // namespace device_reduce_instance
} // namespace device

View File

@@ -11,25 +11,25 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,29 +11,29 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,10 +11,10 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,37 +11,37 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
// Will be moved to use MultiBlockAtomicAdd
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -57,7 +57,7 @@ template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
int NumReduceDim,
ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt>
@@ -89,7 +89,7 @@ void add_device_reduce_instance_threadwise(
AccDataType,
OutDataType,
Rank,
ReduceDims,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
@@ -108,34 +108,36 @@ void add_device_reduce_instance_threadwise(
});
};
#define ADD_THREADWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
template void add_device_reduce_instance_threadwise<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
#define ADD_THREADWISE_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
template void add_device_reduce_instance_threadwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
#define ADD_THREADWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
ADD_THREADWISE_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
#define ADD_THREADWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_THREADWISE_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
NumReduceDim)
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_threadwise<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
@@ -145,15 +147,16 @@ void add_device_reduce_instance_threadwise(
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_THREADWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
ADD_THREADWISE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
#define ADD_THREADWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_THREADWISE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
NumReduceDim)
} // namespace device_reduce_instance
} // namespace device

View File

@@ -11,25 +11,25 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,34 +11,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,16 +11,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -11,34 +11,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,25 +6,25 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,16 +6,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,34 +6,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,16 +6,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,34 +6,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,25 +6,25 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,16 +6,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,34 +6,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,16 +6,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,34 +6,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,13 +6,13 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,13 +6,13 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,13 +6,13 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,25 +6,25 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,16 +6,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,29 +6,29 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,10 +6,10 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,37 +6,37 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
// Will be moved to use MultiBlockAtomicAdd
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,25 +6,25 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,16 +6,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,34 +6,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,16 +6,16 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -6,34 +6,34 @@ namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0);
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance

View File

@@ -9,54 +9,52 @@ namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
template <int Rank, typename ReduceDims, int ReduceOpId, int NanOpt, int IndicesOpt>
template <int Rank, int NumReduceDim, int ReduceOpId, int NanOpt, int IndicesOpt>
struct ReduceDescription
{
static constexpr int Rank_ = Rank;
static constexpr int ReduceOpId_ = ReduceOpId;
static constexpr int NanOpt_ = NanOpt;
static constexpr int IndicesOpt_ = IndicesOpt;
using ReduceDims_ = ReduceDims;
static constexpr int Rank_ = Rank;
static constexpr int NumReduceDim_ = NumReduceDim;
static constexpr int ReduceOpId_ = ReduceOpId;
static constexpr int NanOpt_ = NanOpt;
static constexpr int IndicesOpt_ = IndicesOpt;
};
using reduce_description_instances =
std::tuple<ReduceDescription<4, Sequence<0, 1, 2>, 0, 0, 0>, // for ADD
ReduceDescription<4, Sequence<0>, 0, 0, 0>,
ReduceDescription<2, Sequence<1>, 0, 0, 0>,
using reduce_description_instances = std::tuple<ReduceDescription<4, 3, 0, 0, 0>, // for ADD
ReduceDescription<4, 1, 0, 0, 0>,
ReduceDescription<2, 1, 0, 0, 0>,
ReduceDescription<4, Sequence<0, 1, 2>, 5, 0, 0>, // for AVG
ReduceDescription<4, Sequence<0>, 5, 0, 0>,
ReduceDescription<2, Sequence<1>, 5, 0, 0>,
ReduceDescription<4, 3, 5, 0, 0>, // for AVG
ReduceDescription<4, 1, 5, 0, 0>,
ReduceDescription<2, 1, 5, 0, 0>,
ReduceDescription<4, Sequence<0, 1, 2>, 7, 0, 0>, // for NORM2
ReduceDescription<4, Sequence<0>, 7, 0, 0>,
ReduceDescription<2, Sequence<1>, 7, 0, 0>,
ReduceDescription<4, 3, 7, 0, 0>, // for NORM2
ReduceDescription<4, 1, 7, 0, 0>,
ReduceDescription<2, 1, 7, 0, 0>,
ReduceDescription<4, Sequence<0, 1, 2>, 2, 0, 0>, // for MIN
ReduceDescription<4, Sequence<0>, 2, 0, 0>,
ReduceDescription<2, Sequence<1>, 2, 0, 0>,
ReduceDescription<4, Sequence<0, 1, 2>, 3, 0, 0>, // for MAX
ReduceDescription<4, Sequence<0>, 3, 0, 0>,
ReduceDescription<2, Sequence<1>, 3, 0, 0>,
ReduceDescription<4, Sequence<0, 1, 2>, 4, 0, 0>, // for AMAX
ReduceDescription<4, Sequence<0>, 4, 0, 0>,
ReduceDescription<2, Sequence<1>, 4, 0, 0>,
ReduceDescription<4, 3, 2, 0, 0>, // for MIN
ReduceDescription<4, 1, 2, 0, 0>,
ReduceDescription<2, 1, 2, 0, 0>,
ReduceDescription<4, 3, 3, 0, 0>, // for MAX
ReduceDescription<4, 1, 3, 0, 0>,
ReduceDescription<2, 1, 3, 0, 0>,
ReduceDescription<4, 3, 4, 0, 0>, // for AMAX
ReduceDescription<4, 1, 4, 0, 0>,
ReduceDescription<2, 1, 4, 0, 0>,
ReduceDescription<4, Sequence<0, 1, 2>, 2, 0, 1>, // for MIN
ReduceDescription<4, Sequence<0>, 2, 0, 1>,
ReduceDescription<2, Sequence<1>, 2, 0, 1>,
ReduceDescription<4, Sequence<0, 1, 2>, 3, 0, 1>, // for MAX
ReduceDescription<4, Sequence<0>, 3, 0, 1>,
ReduceDescription<2, Sequence<1>, 3, 0, 1>,
ReduceDescription<4, Sequence<0, 1, 2>, 4, 0, 1>, // for AMAX
ReduceDescription<4, Sequence<0>, 4, 0, 1>,
ReduceDescription<2, Sequence<1>, 4, 0, 1>>;
ReduceDescription<4, 3, 2, 0, 1>, // for MIN
ReduceDescription<4, 1, 2, 0, 1>,
ReduceDescription<2, 1, 2, 0, 1>,
ReduceDescription<4, 3, 3, 0, 1>, // for MAX
ReduceDescription<4, 1, 3, 0, 1>,
ReduceDescription<2, 1, 3, 0, 1>,
ReduceDescription<4, 3, 4, 0, 1>, // for AMAX
ReduceDescription<4, 1, 4, 0, 1>,
ReduceDescription<2, 1, 4, 0, 1>>;
template <typename DescriptionType>
bool description_match(const DescriptionType& description,
int Rank,
const std::vector<int>& ReduceDims,
const std::vector<int>& reduceDims,
ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt)
@@ -66,16 +64,11 @@ bool description_match(const DescriptionType& description,
description.IndicesOpt_ != static_cast<int>(IndicesOpt))
return (false);
if(DescriptionType::ReduceDims_::Size() != ReduceDims.size())
if(DescriptionType::NumReduceDim_ != reduceDims.size())
return (false);
bool result = true;
static_for<0, DescriptionType::ReduceDims_::Size(), 1>{}([&](auto i) {
if(DescriptionType::ReduceDims_::At(i) != ReduceDims[i])
result = false;
});
return (result);
};
@@ -87,33 +80,29 @@ bool description_match(const DescriptionType& description,
namespace ck {
namespace profiler {
template <int Rank, typename ReduceDims>
static std::vector<int> get_reduce_dims()
template <index_t Rank, index_t NumReduceDim>
static inline std::vector<int> get_invariant_dims(const std::vector<int>& reduceDims)
{
std::vector<int> resDims;
assert(NumReduceDim == reduceDims.size());
static_for<0, ReduceDims::Size(), 1>{}([&](auto i) { resDims.push_back(ReduceDims::At(i)); });
int reduceFlag = 0;
return (resDims);
};
template <int Rank, typename ReduceDims>
static std::vector<int> get_invariant_dims()
{
std::vector<int> resDims;
unsigned int incFlag = 0;
static_for<0, ReduceDims::Size(), 1>{}(
[&](auto i) { incFlag = incFlag | (0x1 << ReduceDims::At(i)); });
for(int dim = 0; dim < Rank; dim++)
// flag the bits for the reduceDims
for(int i = 0; i < NumReduceDim; i++)
{
if(incFlag & (0x1 << dim))
continue;
resDims.push_back(dim);
reduceFlag |= 1 << reduceDims[i];
};
return (resDims);
std::vector<int> invariantDims;
// collect invariant dimensions
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) == 0)
{
invariantDims.push_back(i);
};
return invariantDims;
};
template <typename T>
@@ -149,7 +138,7 @@ template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims_,
int NumReduceDim,
ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt>
@@ -159,6 +148,7 @@ void profile_reduce_impl_impl(bool do_verification,
bool do_dumpout,
int nrepeat,
const std::vector<size_t>& inLengths,
const std::vector<int>& reduceDims,
float alpha,
float beta)
{
@@ -203,15 +193,14 @@ void profile_reduce_impl_impl(bool do_verification,
{
Tensor<InDataType> in(inLengths);
const std::vector<int> OuterDims = get_invariant_dims<Rank, ReduceDims_>();
const std::vector<int> ReduceDims = get_reduce_dims<Rank, ReduceDims_>();
std::vector<size_t> outLengths;
if(OuterDims.empty())
const auto invariantDims = get_invariant_dims<Rank, NumReduceDim>(reduceDims);
if(reduceDims.size() == Rank)
outLengths.push_back(1);
else
for(auto dim : OuterDims)
for(auto dim : invariantDims)
outLengths.push_back(inLengths[dim]);
Tensor<OutDataType> out_ref(outLengths);
@@ -302,7 +291,7 @@ void profile_reduce_impl_impl(bool do_verification,
AccDataType,
OutDataType,
Rank,
ReduceDims_,
NumReduceDim,
ReduceOpId,
NanOpt,
IndicesOpt>(reduce0_ptrs);
@@ -311,7 +300,7 @@ void profile_reduce_impl_impl(bool do_verification,
AccDataType,
OutDataType,
Rank,
ReduceDims_,
NumReduceDim,
ReduceOpId,
NanOpt,
IndicesOpt>(reduce0_ptrs);
@@ -321,7 +310,7 @@ void profile_reduce_impl_impl(bool do_verification,
AccDataType,
OutDataType,
Rank,
ReduceDims_,
NumReduceDim,
ReduceOpId,
NanOpt,
IndicesOpt>(reduce0_ptrs);
@@ -330,7 +319,7 @@ void profile_reduce_impl_impl(bool do_verification,
AccDataType,
OutDataType,
Rank,
ReduceDims_,
NumReduceDim,
ReduceOpId,
NanOpt,
IndicesOpt>(reduce1_ptrs);
@@ -341,7 +330,7 @@ void profile_reduce_impl_impl(bool do_verification,
AccDataType,
OutDataType,
Rank,
ReduceDims_,
NumReduceDim,
ReduceOpId,
NanOpt,
IndicesOpt>(reduce2_ptrs);
@@ -358,7 +347,7 @@ void profile_reduce_impl_impl(bool do_verification,
using hCompType = typename type_mapping<AccDataType>::outDataType;
ReductionHost<hInType, hCompType, hOutType, ReduceOpId, PropagateNan, NeedIndices>
hostReduce(in.mDesc, out_ref.mDesc, OuterDims, ReduceDims);
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run(alpha,
reinterpret_cast<const hInType*>(in.mData.data()),
@@ -383,6 +372,7 @@ void profile_reduce_impl_impl(bool do_verification,
i_inStrides,
i_outLengths,
i_outStrides,
reduceDims,
alpha,
beta,
in_dev.GetDeviceBuffer(),
@@ -464,6 +454,7 @@ void profile_reduce_impl_impl(bool do_verification,
i_inStrides,
i_outLengths,
i_outStrides,
reduceDims,
alpha,
beta,
in_dev.GetDeviceBuffer(),
@@ -496,6 +487,7 @@ void profile_reduce_impl_impl(bool do_verification,
inStrides2,
i_outLengths,
i_outStrides,
reduceDims,
alpha,
beta,
ws_dev.GetDeviceBuffer(),
@@ -584,7 +576,7 @@ void profile_reduce_impl(bool do_verification,
bool do_dumpout,
int nrepeat,
const std::vector<size_t>& inLengths,
const std::vector<int>& ReduceDims,
const std::vector<int>& reduceDims,
ReduceTensorOp_t ReduceOpId,
NanPropagation_t NanOpt,
ReduceTensorIndices_t IndicesOpt,
@@ -605,18 +597,26 @@ void profile_reduce_impl(bool do_verification,
using descType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
if(!description_match(
descType{}, inLengths.size(), ReduceDims, ReduceOpId, NanOpt, IndicesOpt))
descType{}, inLengths.size(), reduceDims, ReduceOpId, NanOpt, IndicesOpt))
return;
profile_reduce_impl_impl<InDataType,
AccDataType,
OutDataType,
descType::Rank_,
typename descType::ReduceDims_,
descType::NumReduceDim_,
static_cast<ReduceTensorOp_t>(descType::ReduceOpId_),
static_cast<NanPropagation_t>(descType::NanOpt_),
static_cast<ReduceTensorIndices_t>(descType::IndicesOpt_)>(
do_verification, init_method, do_log, do_dumpout, nrepeat, inLengths, alpha, beta);
do_verification,
init_method,
do_log,
do_dumpout,
nrepeat,
inLengths,
reduceDims,
alpha,
beta);
matched = true;
});

View File

@@ -25,7 +25,7 @@ using ck::ReduceTensorIndices_t;
using ck::ReduceTensorOp_t;
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
{"toReduceDims", required_argument, nullptr, 'R'},
{"reduceDims", required_argument, nullptr, 'R'},
{"reduceOp", required_argument, nullptr, 'O'},
{"compType", required_argument, nullptr, 'C'},
{"outType", required_argument, nullptr, 'W'},
@@ -93,9 +93,9 @@ typedef enum
appDouble = 6,
} appDataType_t;
static void check_reduce_dims(const int rank, const std::vector<int>& toReduceDims)
static void check_reduce_dims(const int rank, const std::vector<int>& reduceDims)
{
for(auto dim : toReduceDims)
for(auto dim : reduceDims)
{
if(dim < 0 || dim >= rank)
throw std::runtime_error("Invalid dimension index specified for Reducing");
@@ -103,7 +103,7 @@ static void check_reduce_dims(const int rank, const std::vector<int>& toReduceDi
unsigned int flag = 0;
for(auto dim : toReduceDims)
for(auto dim : reduceDims)
{
if(flag & (0x1 << dim))
throw std::runtime_error("All toReduce dimensions should be different!");
@@ -122,7 +122,7 @@ class AppArgs
std::vector<size_t> inLengths;
std::vector<size_t> outLengths;
std::vector<int> toReduceDims;
std::vector<int> reduceDims;
std::vector<float> scales;
@@ -152,7 +152,7 @@ class AppArgs
std::cout << "Usage of " << cmd << std::endl;
std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
<< std::endl;
std::cout << "--toReduceDims or -R, comma separated list of to-reduce dimensions"
std::cout << "--reduceDims or -R, comma separated list of to-reduce dimensions"
<< std::endl;
std::cout << "--reduceOp or -O, enum value indicating the reduction operations"
<< std::endl;
@@ -201,7 +201,7 @@ class AppArgs
if(!optarg)
throw std::runtime_error("Invalid option format!");
toReduceDims = getTypeValuesFromString<int>(optarg);
reduceDims = getTypeValuesFromString<int>(optarg);
break;
case 'O':
if(!optarg)
@@ -321,7 +321,7 @@ int profile_reduce(int argc, char* argv[])
int rank = args.inLengths.size();
check_reduce_dims(rank, args.toReduceDims);
check_reduce_dims(rank, args.reduceDims);
if(args.reduceOp == ReduceTensorOp_t::MUL || args.reduceOp == ReduceTensorOp_t::NORM1)
throw std::runtime_error("MUL and NORM1 are not supported by composable kernel!");
@@ -345,7 +345,7 @@ int profile_reduce(int argc, char* argv[])
args.do_dumpout,
args.nrepeat,
args.inLengths,
args.toReduceDims,
args.reduceDims,
args.reduceOp,
args.nanOpt,
args.indicesOpt,
@@ -360,7 +360,7 @@ int profile_reduce(int argc, char* argv[])
args.do_dumpout,
args.nrepeat,
args.inLengths,
args.toReduceDims,
args.reduceDims,
args.reduceOp,
args.nanOpt,
args.indicesOpt,
@@ -378,7 +378,7 @@ int profile_reduce(int argc, char* argv[])
args.do_dumpout,
args.nrepeat,
args.inLengths,
args.toReduceDims,
args.reduceDims,
args.reduceOp,
args.nanOpt,
args.indicesOpt,
@@ -395,7 +395,7 @@ int profile_reduce(int argc, char* argv[])
args.do_dumpout,
args.nrepeat,
args.inLengths,
args.toReduceDims,
args.reduceDims,
args.reduceOp,
args.nanOpt,
args.indicesOpt,
@@ -410,7 +410,7 @@ int profile_reduce(int argc, char* argv[])
args.do_dumpout,
args.nrepeat,
args.inLengths,
args.toReduceDims,
args.reduceDims,
args.reduceOp,
args.nanOpt,
args.indicesOpt,

View File

@@ -1,66 +1,74 @@
#!/bin/bash
PRECISION= ##--half
PRECISION=
##PRECISION=--half
##PRECISION=--double
if test -n $PRECISION && test "$PRECISION" = "--half"; then
CTYPE="-C 1"
ACCTYPE="-C 1"
else
CTYPE=""
ACCTYPE=""
fi
WTYPE=
driver="./bin/ckProfiler"
if [ $# -ge 1 ] ; then
NREPEAT=$1
else
NREPEAT=1
fi
VERIFY="-v $1"
INIT=$2
NREPEAT=$3
Operation=7
#### 0 - ADD, 5 - AVG, 7 - NORM2
Operations="0 5 7"
## for generic validation
for op in $Operation; do
for op in $Operations; do
set -x
./bin/ckProfiler reduce $PRECISION -D 64,4,280,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 4,64,280,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 280,4,64,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 4,64,280,82 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 64,280,82,4 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 700,8192 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 700,1024 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 700,4 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
####### datatype layout reduce dims op acctype verify init repeats
$driver reduce $PRECISION -D 64,4,280,82 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 1,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 0,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 0,1,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,22960 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,22960 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 4,1469440 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 4,1469440 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
set +x
done
Operation=5
#### 0 - ADD, 5 - AVG, 7 - NORM2
Operations=5
## for performance evaluation (resnet50 NHWC => C)
for op in $Operation; do
for op in $Operations; do
set -x
./bin/ckProfiler reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op $CTYPE $WTYPE -v 1 1 $NREPEAT
####### datatype layout reduce dims op acctype verify init repeats
$driver reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT
set +x
done

View File

@@ -1,61 +1,69 @@
#!/bin/bash
PRECISION= ##--half
PRECISION=
##PRECISION=--half
##PRECISION=--double
if [ $# -ge 1 ] ; then
NREPEAT=$1
else
NREPEAT=1
fi
driver="./bin/ckProfiler"
Operation=4
VERIFY="-v $1"
INIT=$2
NREPEAT=$3
LENGTHS=64,4,280,82
#### 2 - MIN, 3 - MAX, 4 - AMAX
Operations="2 4"
## for generic validation
for op in $Operation; do
for op in $Operations; do
for use_idx in 0 1; do
set -x
./bin/ckProfiler reduce $PRECISION -D 64,4,280,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 4,64,280,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 280,4,64,82 -R 0 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 4,64,280,82 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 64,280,82,4 -R 0,1,2 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 700,8192 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 700,1024 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 700,4 -R 1 -O $op $CTYPE -v 1 1 $NREPEAT
####### datatype layout reduce dims op use index verify init repeats
$driver reduce $PRECISION -D 64,4,280,82 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 1,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 0,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 64,4,280,82 -R 0,1,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,22960 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,22960 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 4,1469440 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 4,1469440 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
set +x
done
done
Operations=2
## for performance evaluation (resnet50 NHWC => C)
for op in $Operation; do
for op in $Operations; do
for use_idx in 0 1; do
set -x
./bin/ckProfiler reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
./bin/ckProfiler reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op -I $use_idx -v 1 1 $NREPEAT
####### datatype layout reduce dims op use index verify init repeats
$driver reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
$driver reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT
set +x
done
done