Mirror of https://github.com/ROCm/composable_kernel.git
Synced 2026-05-14 10:09:41 +00:00
Reduction for int8 and bfloat16 (#125)
* Use thread cluster descriptor and explicit M_K 2d descriptor to simplify Blockwise Reduction
* Replace ReduceDims with NumReduceDims as the Device Reduce interface template parameter
* Rename the folders for the pool2d and reduce examples
* Update to reduction test scripts
* Add Readme for pool2d_fwd and reduce_blockwise examples
* Add support for int8_t reduction (ADD/AVG, MIN/MAX/AMAX)
* Tiny fix in reduce profiler and tiny update in reduce testing scripts
* Tiny fix in testing script profile_reduce_no_index.sh
* Tiny fix in testing script profile_reduce_no_index.sh
* Add support for bfp16 reduction (using bhalf_t = ushort)
* Tiny fix in amd_buffer_addressing.hpp
* Tiny change in script/profile_reduce_with_index.sh
* Use AccDataType for Beta value and use element_wise::PassThrough
* Use type_convert for type converting in host layer reduction
* Renaming and refining in Reduction profiler/device layer/examples
* Renaming and refining in Reduction profiler/device layer/examples
* Renaming all NumReduceDims to NumReduceDim
* Fix the leaked type_convert in ThreadwiseTensorSliceTransfer_v2
* Update to testing scripts to add bf16 support
* Added more static_asserts
* Remove buggy tunable configurations defined in device_reduce_instance_xxx.hpp
* Add static_assert to give compile-time warning for incorrect thread slice-size/vector-size configurations
* minor change
* Refine and fix GetWorkspaceSizeInBytes of MultiBlockPartialReduce so that int8 completely passes
* Tiny renaming in gridwise_2d_reduction_multiblock_partial_reduce.hpp
* Tiny fix in script/profile_reduce_no_index.sh
* Refine in DeviceReduce layer with regard to using NumInvariantDim/NumReduceDim or InvariantDims/ReduceDims
* Generic renaming in host reduction and DeviceReduce layer
* Add support for 4-d all dimension reduction in the profiler and add_device_reduce_xxx instances
* Use multi-threading and simplify the host Reduction implementation
* Add ctest for reduction
* Update to clarify the use of the data init method in profile_reduce/example_reduce/test_reduce/
* Update to the reduce CTest executables to enable default testing behavior when no command argument is given
* Renaming
Co-authored-by: Jianfeng yan <jfyan008@gmail.com>
[ROCm/composable_kernel commit: 9a8ee8a39a]
@@ -37,7 +37,7 @@ cmake \
 ```bash
 # -D <xxx> : input 4-d tensor lengths
 # -v <x> : verification (0=no, 1=yes)
-#arg1: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
 #arg2: run kernel # of times (>1)
 ./bin/reduce_blockwise -D 16,64,32,960 -v 1 1 10
 ```
@@ -13,7 +13,7 @@
 #include "device_base.hpp"
 #include "device_reduce_blockwise.hpp"
 #include "host_reduce_util.hpp"
-#include "host_generic_reduction.hpp"
+#include "host_reduction.hpp"
 
 #include "reduction_enums.hpp"
 #include "reduction_operator_mapping.hpp"
@@ -21,13 +21,13 @@
 using namespace ck;
 using namespace ck::tensor_operation::device;
 
-using InDataType = half_float::half;
-using OutDataType = half_float::half;
+using InDataType = ck::half_t;
+using OutDataType = ck::half_t;
 using AccDataType = float;
 
-using kInDataType = ck::half_t;
-using kOutDataType = ck::half_t;
-using kAccDataType = float;
+using HostInDataType = half_float::half;
+using HostOutDataType = half_float::half;
+using HostAccDataType = float;
 
 constexpr int Rank = 4;
 constexpr int NumReduceDim = 3;
@@ -43,9 +43,9 @@ using InElementwiseOperation =
 using AccElementwiseOperation =
     typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
 
-using DeviceReduceInstance = DeviceReduceBlockWise<kInDataType,
-                                                   kAccDataType,
-                                                   kOutDataType,
+using DeviceReduceInstance = DeviceReduceBlockWise<InDataType,
+                                                   AccDataType,
+                                                   OutDataType,
                                                    Rank,
                                                    NumReduceDim,
                                                    ReduceOperation,
@@ -135,6 +135,10 @@ class SimpleAppArgs
         std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
                      "comparing with the host-based reduction"
                   << std::endl;
+        std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
+                     "value, 3=decimal value)"
+                  << std::endl;
+        std::cout << "Arg2 -- number of repeats to run the kernel" << std::endl;
     };
 
     int processArgs(int argc, char* argv[])
@@ -263,20 +267,21 @@ int main(int argc, char* argv[])
     {
         switch(args.init_method)
         {
-        case 0:
-            in.GenerateTensorValue(GeneratorTensor_1<InDataType>{}, num_thread);
-            if(beta != 0.0f)
-                out_ref.GenerateTensorValue(GeneratorTensor_1<InDataType>{}, num_thread);
-            break;
+        case 0: break;
         case 1:
+            in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
+            if(beta != 0.0f)
+                out_ref.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
+            break;
+        case 2:
             in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
             if(beta != 0.0f)
                 out_ref.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
             break;
         default:
-            in.GenerateTensorValue(GeneratorTensor_2<InDataType>{1, 5}, num_thread);
+            in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
            if(beta != 0.0f)
-                out_ref.GenerateTensorValue(GeneratorTensor_2<InDataType>{1, 5}, num_thread);
+                out_ref.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
         }
 
     if(beta != 0.0f)
@@ -293,17 +298,27 @@ int main(int argc, char* argv[])
     if(beta != 0.0f)
         out_dev.ToDevice(out.mData.data());
 
-    size_t indicesSizeInBytes = NeedIndices ? out.mDesc.GetElementSize() * sizeof(int) : 0;
+    size_t indicesSizeInBytes = NeedIndices ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0;
 
     DeviceMem out_indices_dev(indicesSizeInBytes);
 
     if(args.do_verification)
     {
-        ReductionHost<InDataType, AccDataType, OutDataType, ReduceOpId, PropagateNan, NeedIndices>
+        ReductionHost<HostInDataType,
+                      HostAccDataType,
+                      HostOutDataType,
+                      ReduceOpId,
+                      Rank,
+                      NumReduceDim,
+                      PropagateNan,
+                      NeedIndices>
             hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
 
-        hostReduce.Run(
-            alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
+        hostReduce.Run(alpha,
+                       reinterpret_cast<const HostInDataType*>(in.mData.data()),
+                       beta,
+                       reinterpret_cast<HostOutDataType*>(out_ref.mData.data()),
+                       out_indices_ref.mData.data());
     };
 
     const auto i_inLengths = to_int_vector(args.inLengths);
@@ -313,7 +328,7 @@ int main(int argc, char* argv[])
 
     auto reduce = DeviceReduceInstance{};
 
-    auto wsSizeInBytes = reduce.GetWorkspaceSizeInBytes(i_inLengths);
+    auto wsSizeInBytes = reduce.GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
 
     DeviceMem ws_dev(wsSizeInBytes);
 
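Taken together, the example now drives the device op through the reshuffled interface. Below is a minimal sketch of that call sequence; `i_inStrides`, `i_outLengths`, `i_outStrides`, `nrepeat`, the `GetDeviceBuffer()` accessor, and the invoker `Run` signature are assumptions based on the diff and typical composable_kernel examples, and the element-wise ops are default-constructed here even though some reduce ops (e.g. AVG) construct them with a divider.

```cpp
// Sketch only: mirrors the updated example flow, not a drop-in program.
auto reduce = DeviceReduceInstance{};

// Workspace size now depends on which dims are reduced (zero for blockwise).
auto wsSizeInBytes = reduce.GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
DeviceMem ws_dev(wsSizeInBytes);

// reduceDims is passed by value per the new interface; Argument shuffles the
// lengths/strides itself via shuffle_tensor_dimensions<Rank, NumReduceDim>.
auto argument_ptr = reduce.MakeArgumentPointer(i_inLengths,
                                               i_inStrides,  // assumed name
                                               i_outLengths, // assumed name
                                               i_outStrides, // assumed name
                                               reduceDims,
                                               alpha,
                                               beta,
                                               in_dev.GetDeviceBuffer(),
                                               out_dev.GetDeviceBuffer(),
                                               out_indices_dev.GetDeviceBuffer(),
                                               ws_dev.GetDeviceBuffer(),
                                               InElementwiseOperation{},
                                               AccElementwiseOperation{});

auto invoker_ptr = reduce.MakeInvokerPointer();
float ave_time   = invoker_ptr->Run(argument_ptr.get(), nrepeat); // assumed signature
```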
@@ -36,7 +36,7 @@ cmake \
 ## Run ```pool2d_fwd```
 ```bash
 #arg1: verification (0=no, 1=yes)
-#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
 #arg3: run kernel # of times (>1)
 #arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx
 ./example/pool2d_fwd 1 1 10

@@ -236,8 +236,9 @@ int main(int argc, char* argv[])
     switch(init_method)
     {
     case 0: break;
-    case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break;
-    default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
+    case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}); break;
+    case 2: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break;
+    default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
     }
 
     DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
@@ -16,9 +16,11 @@ namespace device {
 template <typename InElementwiseOperation, typename AccElementwiseOperation>
 struct DeviceReduce : public BaseOperator
 {
-    virtual size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths)
+    virtual long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
+                                                 const std::vector<int> reduceDims)
     {
         (void)inLengths;
+        (void)reduceDims;
 
         return (0);
     };
@@ -32,19 +34,19 @@ struct DeviceReduce : public BaseOperator
     };
 
     virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
-                        const std::vector<int>& reduceDims,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                        const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) = 0;
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
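Why the base interface grew a `reduceDims` parameter: for the multiblock partial-reduce path the workspace holds one partial value per (invariant index, block group), and the block-group count is derived from the reduce length, which is only known once the caller says which dimensions are reduced. A hedged caller-side sketch (the handle name and shape values are assumptions):

```cpp
// Query the workspace before allocating; blockwise/threadwise variants return 0.
std::vector<int> inLengths  = {16, 64, 32, 960};
std::vector<int> reduceDims = {0, 1, 3};

ck::long_index_t wsBytes = reduce_ptr->GetWorkspaceSizeInBytes(inLengths, reduceDims);
DeviceMem workspace_dev(wsBytes);
```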
@@ -36,20 +36,20 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                   "Invalid thread cluster size assignments!");
 
+    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
     using IndexDataType = int32_t;
 
     static constexpr bool BetaIsZero = NeedIndices;
 
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-    using InvariantDims =
-        typename conditional<NumInvariantDim == 0,
-                             Sequence<>,
-                             typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
-    using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
 
-    static constexpr index_t srcDims = Rank;
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
-    static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
+    static constexpr index_t numSrcDim = Rank;
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
 
     static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -57,18 +57,18 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
     static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
                                     const std::vector<int>& inStrides)
     {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
 
         const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
 
         const auto in_grid_desc_m_k = [&]() {
-            if constexpr(reduceAllDims)
+            if constexpr(reduceAllDim)
             {
                 const auto one_dim_inDesc = transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                     make_tuple(Sequence<0>{}));
 
                 return transform_tensor_descriptor(one_dim_inDesc,
@@ -79,6 +79,9 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
             }
             else
             {
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
                 const auto reduceDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                 const auto invariantDimLengths =
@@ -93,18 +96,20 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
             }
         }();
 
-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
 
-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K =
+            math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
 
-        auto in_grid_desc_m_k_padded =
-            transform_tensor_descriptor(in_grid_desc_m_k,
-                                        make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                                                   make_right_pad_transform(innerLen, inPad_K)),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
 
         return (in_grid_desc_m_k_padded);
     };
@@ -112,44 +117,45 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
     static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
                                     const std::vector<int>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
 
         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
 
         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));
 
-        const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
+        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
 
-        const auto inPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto inPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
 
-        auto out_grid_desc_m_padded =
-            transform_tensor_descriptor(out_grid_desc_m,
-                                        make_tuple(make_right_pad_transform(outerLen, inPad)),
-                                        make_tuple(Sequence<0>{}),
-                                        make_tuple(Sequence<0>{}));
+        auto out_grid_desc_m_padded = transform_tensor_descriptor(
+            out_grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, inPad)),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0>{}));
         return (out_grid_desc_m_padded);
     };
 
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<int>& inLengths,
-                 const std::vector<int>& inStrides,
-                 const std::vector<int>& outLengths,
-                 const std::vector<int>& outStrides,
-                 const std::vector<int>& reduceDims,
+        Argument(const std::vector<int> inLengths,
+                 const std::vector<int> inStrides,
+                 const std::vector<int> outLengths,
+                 const std::vector<int> outStrides,
+                 const std::vector<int> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
                  OutDataType* out_dev,
                  IndexDataType* out_indices_dev,
                  AccDataType* workspace_dev,
-                 const InElementwiseOperation& in_elementwise_op,
-                 const AccElementwiseOperation& acc_elementwise_op)
+                 const InElementwiseOperation in_elementwise_op,
+                 const AccElementwiseOperation acc_elementwise_op)
             : outLengths_{outLengths},
              outStrides_{outStrides},
              in_dev_{in_dev},
@@ -160,21 +166,21 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
         {
            (void)workspace_dev;
 
-            std::tie(inLengths_, inStrides_) =
-                shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
 
-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_ = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_ = type_convert<AccDataType>(beta);
 
             std::tie(invariant_total_length, reduce_total_length) =
-                get_2d_lengths<Rank, ReduceDims>(inLengths_);
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);
 
-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
                 invariant_lowest_length = 1;
             else
-                invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
+                invariant_lowest_length = inLengths_[NumInvariantDim - 1];
 
-            reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
+            reduce_lowest_length = inLengths_[Rank - 1];
 
             gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                        M_BlockTileSize;
@@ -186,7 +192,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
         std::vector<int> outStrides_;
 
         AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;
 
         const InDataType* in_dev_;
         OutDataType* out_dev_;
@@ -278,18 +284,22 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
 
         if constexpr(InSrcVectorDim == 0)
         {
-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
             {
                 return (false);
             }
             else
             {
-                if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
-                    return (false);
-
-                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
-                    return (false);
+                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
+                    return (false);
+
+                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
+                    return (false);
             };
         }
         else
         {
-            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
+            if(pArg->inStrides_[Rank - 1] != 1)
                 return (false);
 
             if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
@@ -308,19 +318,19 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
     };
 
     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
-                        const std::vector<int>& reduceDims,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                         float alpha,
                        float beta,
                         const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
     {
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
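The M/K padding above just rounds the flattened invariant and reduce lengths up to the block tile. A small self-contained check, assuming a 128-wide M tile and a 256-wide K tile (actual tile sizes come from the thread cluster/slice template parameters):

```cpp
#include <cstdint>

// Stand-in for math::integer_least_multiple: round x up to a multiple of m.
constexpr std::int64_t integer_least_multiple(std::int64_t x, std::int64_t m)
{
    return ((x + m - 1) / m) * m;
}

// invariantLength = 16 * 64 * 32 = 32768 is already a multiple of 128 -> no pad;
// reduceLength = 960 is padded up to 1024 for a 256-wide K block tile.
static_assert(integer_least_multiple(32768, 128) - 32768 == 0, "inPad_M");
static_assert(integer_least_multiple(960, 256) - 960 == 64, "inPad_K");
```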
@@ -37,6 +37,10 @@ struct DeviceReduceBlockWiseSecondCall
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                   "Invalid thread cluster size assignments!");
 
+    static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
     using IndexDataType = int32_t;
 
     static constexpr bool BetaIsZero = NeedIndices;
@@ -46,12 +50,8 @@ struct DeviceReduceBlockWiseSecondCall
         "InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
 
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-    using InvariantDims =
-        typename conditional<NumInvariantDim == 0,
-                             Sequence<>,
-                             typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
 
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
 
     static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -65,18 +65,20 @@ struct DeviceReduceBlockWiseSecondCall
         const auto in_grid_desc_m_k =
             make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
 
-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
 
-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K =
+            math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
 
-        auto in_grid_desc_m_k_padded =
-            transform_tensor_descriptor(in_grid_desc_m_k,
-                                        make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                                                   make_right_pad_transform(innerLen, inPad_K)),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
 
         return (in_grid_desc_m_k_padded);
     };
@@ -84,26 +86,27 @@ struct DeviceReduceBlockWiseSecondCall
     static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
                                     const std::vector<int>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
 
         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
 
         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));
 
-        const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
+        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
 
-        const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto outPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
 
-        auto out_grid_desc_m_padded =
-            transform_tensor_descriptor(out_grid_desc_m,
-                                        make_tuple(make_right_pad_transform(outerLen, outPad)),
-                                        make_tuple(Sequence<0>{}),
-                                        make_tuple(Sequence<0>{}));
+        auto out_grid_desc_m_padded = transform_tensor_descriptor(
+            out_grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, outPad)),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0>{}));
         return (out_grid_desc_m_padded);
     };
@@ -131,8 +134,8 @@ struct DeviceReduceBlockWiseSecondCall
             in_elementwise_op_(in_elementwise_op),
             acc_elementwise_op_(acc_elementwise_op)
         {
-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_ = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_ = type_convert<AccDataType>(beta);
 
             invariant_total_length = inLengths[0];
             reduce_total_length = inLengths[1];
@@ -159,7 +162,7 @@ struct DeviceReduceBlockWiseSecondCall
         std::vector<int> outStrides_;
 
         AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;
 
         const InDataType* in_dev_;
         OutDataType* out_dev_;
@@ -268,19 +271,19 @@ struct DeviceReduceBlockWiseSecondCall
     };
 
     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
-                        const std::vector<int>& reduceDims,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                        float alpha,
                         float beta,
                         const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
     {
         (void)reduceDims;
 
@@ -12,38 +12,30 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-// template <typename preUnaryOpType, typename posUnaryOpType>
-// using DeviceReducePtr = std::unique_ptr<DeviceReduce<preUnaryOpType, posUnaryOpType>>;
-
-template <int Rank, typename ReduceDims>
+// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
+// of reduce dims
+template <int Rank, int NumReduceDim>
 std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
 {
     static_assert(Rank <= 6, "bigger Rank size not supported!");
 
-    size_t tensor_total_length = 1;
-    size_t reduce_total_length = 1;
+    size_t invariant_total_length = 1;
+    size_t reduce_total_length = 1;
 
-    static_for<0, ReduceDims::Size(), 1>{}(
-        [&](auto i) { reduce_total_length *= inLengths[ReduceDims::At(i)]; });
+    constexpr int NumInvariantDim = Rank - NumReduceDim;
 
-    static_for<0, Rank, 1>{}([&](auto i) { tensor_total_length *= inLengths[i.value]; });
+    for(int i = NumInvariantDim; i < Rank; i++)
+        reduce_total_length *= inLengths[i];
 
-    return std::make_pair(tensor_total_length / reduce_total_length, reduce_total_length);
-};
+    for(int i = 0; i < NumInvariantDim; i++)
+        invariant_total_length *= inLengths[i];
 
-template <int x, typename Seq>
-constexpr bool belong()
-{
-    bool inside = false;
-
-    static_for<0, Seq::Size(), 1>{}([&](auto i) { inside = (inside || (x == Seq::At(i))); });
-
-    return (inside);
+    return std::make_pair(invariant_total_length, reduce_total_length);
 };
 
 // helper functions using variadic template arguments
 template <index_t... Ns>
-static auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
+auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
 {
     return make_tuple(static_cast<index_t>(lengths[Ns])...);
 };
@@ -59,16 +51,12 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
 };
 
 template <index_t Rank, index_t NumReduceDim>
-static inline std::pair<std::vector<int>, std::vector<int>>
-shuffle_tensor_dimensions(const std::vector<int>& dimLengths,
-                          const std::vector<int>& dimStrides,
-                          const std::vector<int>& reduceDims)
+std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides,
+                                           const std::vector<int>& reduceDims)
 {
-    std::vector<int> newDimLengths;
-    std::vector<int> newDimStrides;
+    std::vector<int> newLengthsStrides;
 
-    assert(Rank == dimLengths.size() && Rank == dimStrides.size() &&
-           NumReduceDim == reduceDims.size());
+    assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
 
     int reduceFlag = 0;
 
@@ -82,19 +70,17 @@ shuffle_tensor_dimensions(const std::vector<int>& dimLengths,
     for(int i = 0; i < Rank; i++)
         if((reduceFlag & (1 << i)) == 0)
         {
-            newDimLengths.push_back(dimLengths[i]);
-            newDimStrides.push_back(dimStrides[i]);
+            newLengthsStrides.push_back(origLengthsStrides[i]);
         };
 
     // collect reduce dimensions
     for(int i = 0; i < Rank; i++)
        if((reduceFlag & (1 << i)) > 0)
         {
-            newDimLengths.push_back(dimLengths[i]);
-            newDimStrides.push_back(dimStrides[i]);
+            newLengthsStrides.push_back(origLengthsStrides[i]);
         };
 
-    return std::make_pair(newDimLengths, newDimStrides);
+    return newLengthsStrides;
 };
 
 } // namespace device
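For reference, a standalone mimic of the two helpers above (an illustration of the intended semantics, not the library code): shuffle_tensor_dimensions moves the reduced dimensions to the back while keeping each group in original order, and get_2d_lengths then splits the shuffled lengths into invariant and reduce products.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    // Rank = 4, NumReduceDim = 2, reduceDims = {0, 2}, lengths = {16, 64, 32, 960}.
    // shuffle_tensor_dimensions<4, 2>({16, 64, 32, 960}, {0, 2}) keeps the invariant
    // dims (1 and 3) first, then appends the reduced dims (0 and 2):
    std::vector<int> shuffled = {64, 960, 16, 32};

    // get_2d_lengths<4, 2>(shuffled) multiplies the leading NumInvariantDim lengths
    // and the trailing NumReduceDim lengths:
    std::size_t invariant_total_length = 64ull * 960; // 61440
    std::size_t reduce_total_length    = 16ull * 32;  // 512

    assert(invariant_total_length == 61440 && reduce_total_length == 512);
    (void)shuffled;
    return 0;
}
```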
@@ -39,18 +39,18 @@ struct DeviceReduceMultiBlockAtomicAdd
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                   "Invalid thread cluster size assignments!");
 
+    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
     using IndexDataType = int32_t;
 
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-    using InvariantDims =
-        typename conditional<NumInvariantDim == 0,
-                             Sequence<>,
-                             typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
-    using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
 
-    static constexpr index_t srcDims = Rank;
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
-    static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
+    static constexpr index_t numSrcDim = Rank;
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
 
     static constexpr bool support_AtomicAdd =
         std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
@@ -67,18 +67,18 @@ struct DeviceReduceMultiBlockAtomicAdd
                                     int blkGroupSize,
                                     int kBlockTileIterations)
     {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
 
         const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
 
         const auto in_grid_desc_m_k = [&]() {
-            if constexpr(reduceAllDims)
+            if constexpr(reduceAllDim)
             {
                 const auto one_dim_inDesc = transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                     make_tuple(Sequence<0>{}));
 
                 return transform_tensor_descriptor(one_dim_inDesc,
@@ -89,6 +89,9 @@ struct DeviceReduceMultiBlockAtomicAdd
             }
             else
             {
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
                 const auto reduceDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                 const auto invariantDimLengths =
@@ -103,19 +106,20 @@ struct DeviceReduceMultiBlockAtomicAdd
             }
         }();
 
-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
 
         const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
 
-        auto in_grid_desc_m_k_padded =
-            transform_tensor_descriptor(in_grid_desc_m_k,
-                                        make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                                                   make_right_pad_transform(innerLen, inPad_K)),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
 
         return (in_grid_desc_m_k_padded);
     };
@@ -123,44 +127,45 @@ struct DeviceReduceMultiBlockAtomicAdd
     static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
                                     const std::vector<int>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
 
         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
 
         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));
 
-        const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
+        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
 
-        const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto outPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
 
-        auto out_grid_desc_m_padded =
-            transform_tensor_descriptor(out_grid_desc_m,
-                                        make_tuple(make_right_pad_transform(outerLen, outPad)),
-                                        make_tuple(Sequence<0>{}),
-                                        make_tuple(Sequence<0>{}));
+        auto out_grid_desc_m_padded = transform_tensor_descriptor(
+            out_grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, outPad)),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0>{}));
         return (out_grid_desc_m_padded);
     };
 
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<int>& inLengths,
-                 const std::vector<int>& inStrides,
-                 const std::vector<int>& outLengths,
-                 const std::vector<int>& outStrides,
-                 const std::vector<int>& reduceDims,
+        Argument(const std::vector<int> inLengths,
+                 const std::vector<int> inStrides,
+                 const std::vector<int> outLengths,
+                 const std::vector<int> outStrides,
+                 const std::vector<int> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
                  OutDataType* out_dev,
                  IndexDataType* out_indices_dev,
                  AccDataType* workspace_dev,
-                 const InElementwiseOperation& in_elementwise_op,
-                 const AccElementwiseOperation& acc_elementwise_op)
+                 const InElementwiseOperation in_elementwise_op,
+                 const AccElementwiseOperation acc_elementwise_op)
             : outLengths_{outLengths},
              outStrides_{outStrides},
              in_dev_{in_dev},
@@ -171,21 +176,21 @@ struct DeviceReduceMultiBlockAtomicAdd
             (void)out_indices_dev;
             (void)workspace_dev;
 
-            std::tie(inLengths_, inStrides_) =
-                shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
 
-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_ = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_ = type_convert<AccDataType>(beta);
 
             std::tie(invariant_total_length, reduce_total_length) =
-                get_2d_lengths<Rank, ReduceDims>(inLengths_);
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);
 
-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
                invariant_lowest_length = 1;
             else
-                invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
+                invariant_lowest_length = inLengths_[NumInvariantDim - 1];
 
-            reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
+            reduce_lowest_length = inLengths_[Rank - 1];
 
             int iterations = 1;
             while(true)
@@ -218,7 +223,7 @@ struct DeviceReduceMultiBlockAtomicAdd
         std::vector<int> outStrides_;
 
         AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;
 
         const InDataType* in_dev_;
         OutDataType* out_dev_;
@@ -334,18 +339,22 @@ struct DeviceReduceMultiBlockAtomicAdd
 
         if constexpr(InSrcVectorDim == 0)
         {
-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
            {
                 return (false);
             }
             else
             {
-                if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
-                    return (false);
-
-                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
-                    return (false);
+                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
+                    return (false);
+
+                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
+                    return (false);
             };
         }
         else
         {
-            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
+            if(pArg->inStrides_[Rank - 1] != 1)
                 return (false);
 
             if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
@@ -371,19 +380,19 @@ struct DeviceReduceMultiBlockAtomicAdd
     };
 
     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
-                        const std::vector<int>& reduceDims,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
     {
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
@@ -37,31 +37,35 @@ struct DeviceReduceMultiBlockPartialReduce
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                   "Invalid thread cluster size assignments!");
 
+    static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                      (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
+    static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
+
     using IndexDataType = int32_t;
 
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-    using InvariantDims =
-        typename conditional<NumInvariantDim == 0,
-                             Sequence<>,
-                             typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
-    using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
 
-    static constexpr index_t srcDims = Rank;
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
-    static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
+    static constexpr index_t numSrcDim = Rank;
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
 
     static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
 
-    size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths) override
+    static constexpr int MaxBlockGroupSize = 256;
+
+    long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
+                                         const std::vector<int> reduceDims) override
     {
         size_t invariant_total_length;
         size_t reduce_total_length;
 
+        auto inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+
         std::tie(invariant_total_length, reduce_total_length) =
-            get_2d_lengths<Rank, ReduceDims>(inLengths);
+            get_2d_lengths<Rank, NumReduceDim>(inLengths_);
 
         int iterations = 1;
         while(true)
@@ -69,8 +73,7 @@ struct DeviceReduceMultiBlockPartialReduce
             int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                                    (K_BlockTileSize * iterations);
 
-            // we want the blkGroupSize be not more than 128
-            if(testBlkGroupSize <= 128)
+            if(testBlkGroupSize <= MaxBlockGroupSize)
                 break;
 
             iterations++;
@@ -79,11 +82,12 @@ struct DeviceReduceMultiBlockPartialReduce
         int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                            (K_BlockTileSize * iterations);
 
-        size_t workspace_size = invariant_total_length * blkGroupSize;
+        long_index_t workspace_size = invariant_total_length * blkGroupSize;
 
-        size_t wsSizeInBytes =
-            !NeedIndices ? workspace_size * sizeof(AccDataType)
-                         : workspace_size * (sizeof(AccDataType) + sizeof(int)) + 64 + sizeof(int);
+        long_index_t wsSizeInBytes =
+            !NeedIndices
+                ? workspace_size * sizeof(AccDataType)
+                : workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int);
 
         return (wsSizeInBytes);
     };
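A quick arithmetic check of the workspace formula above, with all numbers assumed (invariant_total_length = 61440, blkGroupSize = 4, AccDataType = float, 4-byte float/int target): the with-indices case appends an int32 buffer plus 64 bytes of alignment slack. Note the loop bound is now the named constant MaxBlockGroupSize = 256, replacing the hard-coded 128 and its stale comment.

```cpp
#include <cstdint>

constexpr long long workspace_size = 61440LL * 4; // invariant_total_length * blkGroupSize

static_assert(workspace_size * sizeof(float) == 983040, "no-indices workspace bytes");
static_assert(workspace_size * (sizeof(float) + sizeof(std::int32_t)) + 64 + sizeof(int)
                  == 1966148,
              "with-indices workspace bytes");
```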
@@ -95,18 +99,18 @@ struct DeviceReduceMultiBlockPartialReduce
                                     int blkGroupSize,
                                     int kBlockTileIterations)
     {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
 
         const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
 
         const auto in_grid_desc_m_k = [&]() {
-            if constexpr(reduceAllDims)
+            if constexpr(reduceAllDim)
            {
                 const auto one_dim_inDesc = transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                     make_tuple(Sequence<0>{}));
 
                 return transform_tensor_descriptor(one_dim_inDesc,
@@ -117,6 +121,9 @@ struct DeviceReduceMultiBlockPartialReduce
             }
             else
             {
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
                 const auto reduceDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                 const auto invariantDimLengths =
@@ -131,32 +138,35 @@ struct DeviceReduceMultiBlockPartialReduce
             }
         }();
 
-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
 
         const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
 
-        auto in_grid_desc_m_k_padded =
-            transform_tensor_descriptor(in_grid_desc_m_k,
-                                        make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                                                   make_right_pad_transform(innerLen, inPad_K)),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
 
         return (in_grid_desc_m_k_padded);
     };
 
-    static auto MakeWorkspace2dDescriptor(int outerLen, int blkGroupSize)
+    static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize)
     {
-        auto ws_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(outerLen, blkGroupSize));
+        auto ws_desc_m_k =
+            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
 
-        const auto wsPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto wsPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
 
         auto ws_desc_m_k_padded =
             transform_tensor_descriptor(ws_desc_m_k,
-                                        make_tuple(make_right_pad_transform(outerLen, wsPad),
+                                        make_tuple(make_right_pad_transform(invariantLength, wsPad),
                                                    make_pass_through_transform(blkGroupSize)),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -166,19 +176,19 @@ struct DeviceReduceMultiBlockPartialReduce
 
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<int>& inLengths,
-                 const std::vector<int>& inStrides,
-                 const std::vector<int>& outLengths,
-                 const std::vector<int>& outStrides,
-                 const std::vector<int>& reduceDims,
+        Argument(const std::vector<int> inLengths,
+                 const std::vector<int> inStrides,
+                 const std::vector<int> outLengths,
+                 const std::vector<int> outStrides,
+                 const std::vector<int> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
                  OutDataType* out_dev,
                  IndexDataType* out_indices_dev,
                  AccDataType* workspace_dev,
-                 const InElementwiseOperation& in_elementwise_op,
-                 const AccElementwiseOperation& acc_elementwise_op)
+                 const InElementwiseOperation in_elementwise_op,
+                 const AccElementwiseOperation acc_elementwise_op)
             : outLengths_{outLengths},
              outStrides_{outStrides},
              in_dev_{in_dev},
@@ -188,21 +198,21 @@ struct DeviceReduceMultiBlockPartialReduce
             in_elementwise_op_{in_elementwise_op},
             acc_elementwise_op_{acc_elementwise_op}
         {
-            std::tie(inLengths_, inStrides_) =
-                shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
 
-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_ = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_ = type_convert<AccDataType>(beta);
 
             std::tie(invariant_total_length, reduce_total_length) =
-                get_2d_lengths<Rank, ReduceDims>(inLengths_);
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);
 
-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
                 invariant_lowest_length = 1;
             else
-                invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
+                invariant_lowest_length = inLengths_[NumInvariantDim - 1];
 
-            reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
+            reduce_lowest_length = inLengths_[Rank - 1];
 
             int iterations = 1;
             while(true)
@@ -210,8 +220,7 @@ struct DeviceReduceMultiBlockPartialReduce
            int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                                    (K_BlockTileSize * iterations);
 
-            // we want the blkGroupSize be not more than 128
-            if(testBlkGroupSize <= 128)
+            if(testBlkGroupSize <= MaxBlockGroupSize)
                 break;
 
             iterations++;
@@ -241,7 +250,7 @@ struct DeviceReduceMultiBlockPartialReduce
         std::vector<int> outStrides_;
 
         AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;
 
         const InDataType* in_dev_;
         OutDataType* out_dev_;
@@ -337,18 +346,22 @@ struct DeviceReduceMultiBlockPartialReduce
 
         if constexpr(InSrcVectorDim == 0)
        {
-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
             {
                 return (false);
             }
             else
             {
-                if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
-                    return (false);
-
-                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
-                    return (false);
+                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
+                    return (false);
+
+                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
+                    return (false);
             };
         }
         else
         {
-            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
+            if(pArg->inStrides_[Rank - 1] != 1)
                 return (false);
 
             if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
@@ -371,19 +384,19 @@ struct DeviceReduceMultiBlockPartialReduce
     };
 
     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
-                        const std::vector<int>& reduceDims,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
     {
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
@@ -36,20 +36,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
"Threadwise can only be called with KThreadClusterSize be 1 !");

static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");

using IndexDataType = int32_t;

static constexpr bool BetaIsZero = NeedIndices;

static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);

static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
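
Once the reduced dimensions have been shuffled to the back, the whole split is described by `NumInvariantDim = Rank - NumReduceDim`: dimensions `[0, NumInvariantDim)` are invariant and `[NumInvariantDim, Rank)` are reduced, which is why the explicit `InvariantDims`/`ReduceDims` sequences above become redundant. A standalone compile-time sketch of the convention (not the CK types themselves):

```cpp
// Standalone sketch of the invariant/reduce split convention.
template <int Rank, int NumReduceDim>
struct DimSplit
{
    static constexpr int NumInvariantDim = Rank - NumReduceDim;
    // the output keeps at least one dimension even when everything is reduced
    static constexpr int numDstDim    = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
};

static_assert(DimSplit<4, 1>::NumInvariantDim == 3, "e.g. reduce one dim of a 4-d tensor");
static_assert(DimSplit<4, 4>::reduceAllDim, "all-dimension reduction");
static_assert(DimSplit<4, 4>::numDstDim == 1, "scalar output is still kept 1-d");
```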
@@ -57,18 +57,18 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDims)
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));

return transform_tensor_descriptor(one_dim_inDesc,
@@ -79,6 +79,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
@@ -93,18 +96,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
}
}();

const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;

auto in_grid_desc_m_k_padded =
transform_tensor_descriptor(in_grid_desc_m_k,
make_tuple(make_right_pad_transform(outerLen, inPad_M),
make_right_pad_transform(innerLen, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

return (in_grid_desc_m_k_padded);
};
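
The right-pad transforms above round the merged M and K lengths up to whole block tiles; the pad amount is just `integer_least_multiple(len, tile) - len`. A small host-side sketch of the arithmetic, with made-up lengths:

```cpp
#include <cassert>

// ceil(len / tile) * tile, as math::integer_least_multiple computes it
int integer_least_multiple(int len, int tile) { return ((len + tile - 1) / tile) * tile; }

int main()
{
    const int invariantLength = 100, reduceLength = 333; // illustrative lengths
    const int M_BlockTileSize = 64, K_BlockTileSize = 128;

    const int inPad_M = integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
    const int inPad_K = integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;

    assert(inPad_M == 28); // 100 padded up to 128
    assert(inPad_K == 51); // 333 padded up to 384
    return 0;
}
```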
@@ -112,44 +117,45 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});

auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));

const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

auto out_grid_desc_m_padded =
transform_tensor_descriptor(out_grid_desc_m,
make_tuple(make_right_pad_transform(outerLen, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};

struct Argument : public BaseArgument
{
Argument(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op)
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
@@ -161,21 +167,21 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
{
(void)workspace_dev;

std::tie(inLengths_, inStrides_) =
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);

std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths_);
get_2d_lengths<Rank, NumReduceDim>(inLengths_);

if constexpr(InvariantDims::Size() == 0)
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[NumInvariantDim - 1];

reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[Rank - 1];

gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
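
Splitting `shuffle_tensor_dimensions` into one call per array means it now maps a single `std::vector<int>` at a time: dimensions named in `reduceDims` move to the back while the invariant dimensions keep their relative order at the front. A plausible host-side sketch of that permutation (not the CK implementation itself):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Sketch: move the dimensions listed in reduceDims to the back,
// keeping the relative order of the remaining (invariant) dimensions.
std::vector<int> shuffle_dims(const std::vector<int>& values, const std::vector<int>& reduceDims)
{
    std::vector<int> result;
    for(int dim = 0; dim < static_cast<int>(values.size()); ++dim)
        if(std::find(reduceDims.begin(), reduceDims.end(), dim) == reduceDims.end())
            result.push_back(values[dim]);
    for(int dim : reduceDims)
        result.push_back(values[dim]);
    return result;
}

int main()
{
    // reduce dimensions 0 and 2 of a 4-d tensor with lengths {16, 64, 32, 960}
    const std::vector<int> shuffled = shuffle_dims({16, 64, 32, 960}, {0, 2});
    assert((shuffled == std::vector<int>{64, 960, 16, 32}));
    return 0;
}
```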
@@ -187,7 +193,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
std::vector<int> outStrides_;

AccDataType alpha_;
OutDataType beta_;
AccDataType beta_;

const InDataType* in_dev_;
OutDataType* out_dev_;
@@ -278,18 +284,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE

if constexpr(InSrcVectorDim == 0)
{
if constexpr(InvariantDims::Size() == 0)
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);

if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
return (false);

if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
if(pArg->inStrides_[Rank - 1] != 1)
return (false);

if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
@@ -310,19 +320,19 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
};

std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op) override
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
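
`IsSupportedArgument` above accepts a vectorized configuration only when the dimension selected by `InSrcVectorDim` is contiguous (innermost stride 1) and its length divides evenly by `InSrcVectorSize`. The same test, written out as a small host-side predicate for illustration:

```cpp
#include <vector>

// Sketch of the vector-load legality test used by IsSupportedArgument.
bool can_vectorize(const std::vector<int>& strides, int lowest_length, int vector_size)
{
    // the innermost dimension must be contiguous in memory ...
    if(strides.back() != 1)
        return false;
    // ... and its length must be a whole number of vectors
    return lowest_length % vector_size == 0;
}
```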
@@ -1,6 +1,5 @@
#ifndef CK_ELEMENT_WISE_OPERATION_HPP
#define CK_ELEMENT_WISE_OPERATION_HPP
#include "data_type.hpp"

#include "data_type.hpp"

@@ -19,6 +18,8 @@ struct PassThrough
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }

__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }

__host__ __device__ void operator()(double& y, const double& x) const { y = x; }
};

struct Add
@@ -239,6 +240,24 @@ struct UnaryIdentic<int32_t, int32_t, false>
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; };
};

template <>
struct UnaryIdentic<int32_t, int32_t, true>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; };

int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<int8_t, int8_t, false>
{
__host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; };

__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; };
};

template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;

@@ -311,6 +330,19 @@ struct UnaryAbs<double, double>
__host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
};

template <>
struct UnaryAbs<int8_t, int8_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

__host__ __device__ void operator()(int8_t& y, const int8_t& x) const
{
int8_t sgn = x >> (8 - 1);

y = (x ^ sgn) - sgn;
};
};

template <typename Y, typename X>
struct UnarySqrt;
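
The new `UnaryAbs<int8_t, int8_t>` uses the classic branchless absolute value: arithmetic right shift replicates the sign bit (`sgn` is 0 for non-negative inputs, -1 for negative ones), and `(x ^ sgn) - sgn` flips and re-biases negative values. A quick host-side check of the trick; note that, like `std::abs`, it leaves INT8_MIN unchanged due to two's-complement wraparound:

```cpp
#include <cassert>
#include <cstdint>

int8_t abs_branchless(int8_t x)
{
    const int8_t sgn = x >> (8 - 1);               // 0 if x >= 0, -1 (all ones) if x < 0
    return static_cast<int8_t>((x ^ sgn) - sgn);   // identity for x >= 0, -x for x < 0
}

int main()
{
    assert(abs_branchless(0) == 0);
    assert(abs_branchless(37) == 37);
    assert(abs_branchless(-37) == 37);
    assert(abs_branchless(-128) == -128); // two's-complement edge case, same as std::abs
    return 0;
}
```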
@@ -33,6 +33,7 @@
#include "reduction_functions_blockwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"

namespace ck {

@@ -52,23 +53,25 @@ __global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
constexpr bool IsSecondCall = false;

GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
}
else
{
@@ -102,23 +105,25 @@ kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
constexpr bool IsSecondCall = true;

GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
}
else
{
@@ -156,6 +161,11 @@ template <typename InDataType,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_blockwise
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");

static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
@@ -174,8 +184,7 @@ struct GridwiseReduction_mk_to_m_blockwise
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));

template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;

static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -183,17 +192,24 @@ struct GridwiseReduction_mk_to_m_blockwise
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

template <bool IsSecondCall>
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(IsSecondCall)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
};

using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
@@ -345,7 +361,7 @@ struct GridwiseReduction_mk_to_m_blockwise
priorDstValueBuf);

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
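
The one-line change above matters for the new data types: the old code multiplied `priorDstValueBuf[I] * beta` in the destination type and only converted afterwards, while the fix converts the prior output value to `AccDataType` first and multiplies there. A minimal illustration of the difference with int8 outputs (values chosen so the narrow-type product wraps):

```cpp
#include <cassert>
#include <cstdint>

int main()
{
    const int8_t  prior = 100; // prior output value
    const int32_t beta  = 2;   // beta already held in AccDataType

    // old order: multiply, squeeze through int8, then widen -> wraps around
    const int32_t wrong = static_cast<int32_t>(static_cast<int8_t>(prior * beta));
    // fixed order: widen first, then multiply in the accumulation type
    const int32_t right = static_cast<int32_t>(prior) * beta;

    assert(wrong == -56); // 200 wrapped modulo 256 (typical two's-complement behavior)
    assert(right == 200);
    return 0;
}
```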
@@ -355,7 +371,7 @@ struct GridwiseReduction_mk_to_m_blockwise
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
@@ -366,7 +382,7 @@ struct GridwiseReduction_mk_to_m_blockwise
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});

threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
@@ -379,7 +395,7 @@ struct GridwiseReduction_mk_to_m_blockwise
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
@@ -570,7 +586,7 @@ struct GridwiseReduction_mk_to_m_blockwise
priorDstValueBuf);

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
@@ -580,7 +596,7 @@ struct GridwiseReduction_mk_to_m_blockwise
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
@@ -591,14 +607,14 @@ struct GridwiseReduction_mk_to_m_blockwise
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});

auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<index_t>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
@@ -609,7 +625,7 @@ struct GridwiseReduction_mk_to_m_blockwise
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<index_t>{});
PassThroughOp{});

threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
@@ -631,11 +647,14 @@ struct GridwiseReduction_mk_to_m_blockwise
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_ws_values_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");

using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
@@ -841,7 +860,7 @@ struct GridwiseReduction_mk_to_m_blockwise
priorDstValueBuf);

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
@@ -851,7 +870,7 @@ struct GridwiseReduction_mk_to_m_blockwise
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
@@ -862,14 +881,14 @@ struct GridwiseReduction_mk_to_m_blockwise
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});

auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<IndexDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
@@ -880,7 +899,7 @@ struct GridwiseReduction_mk_to_m_blockwise
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<index_t>{});
PassThroughOp{});

threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),

@@ -32,6 +32,7 @@
#include "reduction_functions_blockwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"

namespace ck {

@@ -84,6 +85,11 @@ template <typename InDataType,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");

static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
@@ -109,8 +115,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
ReduceOperation,
PropagateNan>;

template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;

static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -249,7 +254,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
@@ -260,7 +265,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});

threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);

@@ -23,8 +23,8 @@
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP

#include "reduction_common.hpp"
#include "reduction_operator.hpp"
@@ -32,6 +32,7 @@
#include "reduction_functions_blockwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"

namespace ck {

@@ -101,6 +102,12 @@ template <typename InDataType,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
{
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");

static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");

static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
@@ -119,8 +126,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));

template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;

static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -238,9 +244,6 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);

constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// Each block executes multiple parallel reductions on the LDS, and due to the use of
// vector_load, each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -254,6 +257,9 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});

constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

if(thread_k_cluster_id == 0)
{
auto threadwise_workspace_store =
@@ -261,7 +267,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType,
decltype(reduced_data_desc),
WorkspaceDesc_M_K,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>,
1,
@@ -273,7 +279,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp<AccDataType>{});
PassThroughOp{});

threadwise_workspace_store.Run(reduced_data_desc,
make_tuple(I0, I0),
@@ -450,7 +456,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType,
decltype(reduced_data_desc),
WorkspaceDesc_M_K,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>,
1,
@@ -462,14 +468,14 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp<AccDataType>{});
PassThroughOp{});

auto threadwise_workspace_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
WorkspaceDesc_M_K,
PassThroughOp<IndexDataType>,
PassThroughOp,
Sequence<MThreadSliceSize, 1>,
Sequence<0, 1>,
1,
@@ -481,7 +487,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp<IndexDataType>{});
PassThroughOp{});

threadwise_workspace_val_store.Run(reduced_data_desc,
make_tuple(I0, I0),

@@ -31,6 +31,7 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"

namespace ck {

@@ -50,7 +51,7 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
const AccElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
{
@@ -101,11 +102,15 @@ template <typename InDataType,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");

using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;

static constexpr auto I0 = Number<0>{};

@@ -115,7 +120,7 @@ struct GridwiseReduction_mk_to_m_threadwise
const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
{
@@ -228,7 +233,7 @@ struct GridwiseReduction_mk_to_m_threadwise
priorDstValue_buf);

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
};
@@ -238,7 +243,7 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
@@ -248,7 +253,7 @@ struct GridwiseReduction_mk_to_m_threadwise
false>(
out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});

threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
@@ -260,7 +265,7 @@ struct GridwiseReduction_mk_to_m_threadwise
const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType beta,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
{
@@ -387,7 +392,7 @@ struct GridwiseReduction_mk_to_m_threadwise
priorDstValue_buf);

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
};
@@ -397,7 +402,7 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<AccDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
@@ -407,14 +412,14 @@ struct GridwiseReduction_mk_to_m_threadwise
false>(
out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp<AccDataType>{});
PassThroughOp{});

auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp<IndexDataType>,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
@@ -424,7 +429,7 @@ struct GridwiseReduction_mk_to_m_threadwise
false>(
out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThroughOp<IndexDataType>{});
PassThroughOp{});

threadwise_dst_val_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf);

@@ -79,6 +79,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
}

__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
@@ -250,6 +252,8 @@ struct ThreadwiseTensorSliceTransfer_v2
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
}

__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -313,7 +317,8 @@ struct ThreadwiseTensorSliceTransfer_v2
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);

dst_buf(Number<dst_offset>{}) = src_vector.template AsType<SrcData>()[i];
dst_buf(Number<dst_offset>{}) =
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
});

if constexpr(idx_1d.value != num_access - 1)
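
This is the "leaked type_convert" fix named in the commit message: the element now passes through an explicit `type_convert<DstData>` instead of an implicit assignment, which is what makes `ThreadwiseTensorSliceTransfer_v2` safe when `SrcData` and `DstData` differ (for example, float accumulators stored to int8 or bfloat16 buffers). As a rough plain-C++ analogy, an explicit conversion point is where any rounding or saturation policy can live:

```cpp
#include <cstdint>

// Rough stand-in for type_convert: one explicit, auditable conversion point
// per source/destination pair (the real CK helper covers many more types).
template <typename Dst, typename Src>
Dst convert(Src v)
{
    return static_cast<Dst>(v);
}

void copy_with_convert(const float* src, int8_t* dst, int n)
{
    for(int i = 0; i < n; ++i)
        dst[i] = convert<int8_t>(src[i]); // was: dst[i] = src[i]; (implicit, silent)
}
```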
if constexpr(idx_1d.value != num_access - 1)
|
||||
@@ -439,6 +444,10 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
|
||||
{
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
@@ -1016,7 +1025,8 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0, "wrong!");
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
}
|
||||
|
||||
template <typename SrcRefToOriginDisplacement,
|
||||
|
||||
@@ -637,19 +637,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
|
||||
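
The rename from `fp16x2`/`fp16x4` to `i16x2`/`i16x4` reflects that bfloat16 travels through the buffer-store intrinsics as raw 16-bit integers (`bhalf_t` is `ushort` in this commit, per the message above). A host-side sketch of the underlying bit relationship between float and bfloat16, assuming simple truncation (the conversion CK actually performs may round differently):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// bfloat16 is the top 16 bits of an IEEE-754 float; truncation is the
// simplest (round-toward-zero) conversion and is assumed here only
// for illustration.
uint16_t float_to_bf16_trunc(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

float bf16_to_float(uint16_t h)
{
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main()
{
    assert(bf16_to_float(float_to_bf16_trunc(1.0f)) == 1.0f); // exactly representable
    return 0;
}
```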
@@ -606,6 +606,12 @@ struct sequence_map_inverse
SeqMap::Size()>::type;
};

template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr bool operator==(Sequence<Xs...>, Sequence<Ys...>)
{
return ((Xs == Ys) && ...);
}

template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
{
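
The new `operator==` compares two compile-time `Sequence`s with a C++17 fold expression; since it folds over pairwise `Xs == Ys`, it is meant for sequences of equal size (as in the new `static_assert`s elsewhere in this commit). A standalone miniature of the same pattern:

```cpp
template <int... Is>
struct Seq
{
};

// pairwise fold, exactly like the Sequence operator== added above
template <int... Xs, int... Ys>
constexpr bool operator==(Seq<Xs...>, Seq<Ys...>)
{
    return ((Xs == Ys) && ...);
}

static_assert(Seq<0, 0, 0>{} == Seq<0, 0, 0>{}, "equal sequences compare equal");
static_assert(!(Seq<1, 2>{} == Seq<1, 3>{}), "any differing element breaks equality");
```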
@@ -37,6 +37,10 @@ struct SpaceFillingCurve

__host__ __device__ static constexpr index_t GetNumOfAccess()
{
static_assert(TensorLengths::Size() == ScalarsPerAccess::Size());
static_assert(TensorLengths{} % ScalarsPerAccess{} ==
typename uniform_sequence_gen<TensorLengths::Size(), 0>::type{});

return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) /
ScalarPerVector;
}
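
The new `static_assert`s guard the assumption behind `GetNumOfAccess()`: every tensor length must divide evenly by the scalars-per-access in that dimension, so the access count is simply the total element count divided by the vector size. The arithmetic, spelled out with illustrative numbers:

```cpp
#include <cassert>

int main()
{
    // illustrative 2-d slice: 4 x 8 elements, accessed as 1 x 4 vectors
    const int lengths[2]            = {4, 8};
    const int scalars_per_access[2] = {1, 4};

    // the guarded precondition: each length divides by its access width
    assert(lengths[0] % scalars_per_access[0] == 0);
    assert(lengths[1] % scalars_per_access[1] == 0);

    const int scalar_per_vector = scalars_per_access[0] * scalars_per_access[1];
    const int num_of_access     = (lengths[0] * lengths[1]) / scalar_per_vector;

    assert(num_of_access == 8); // 32 elements / 4 per vector
    return 0;
}
```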
@@ -1,424 +0,0 @@

/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef HOST_GENERIC_REDUCTION_HPP_
#define HOST_GENERIC_REDUCTION_HPP_

#include <vector>
#include <functional>
#include <limits>
#include <type_traits>
#include <cassert>
#include <cmath>

#include "reduction_enums.hpp"
#include "host_reduce_util.hpp"

using float16 = half_float::half;

namespace ck {

namespace host_reduce {

template <typename T>
static void
get_all_indexes(const std::vector<T>& dimLengths, int dim, std::vector<std::vector<T>>& indexes)
{
if(dim < dimLengths.size())
{
std::vector<std::vector<T>> updated_indexes;

if(dim == 0)
{
assert(indexes.size() == 0);
assert(dimLengths[dim] > 0);
for(T i = 0; i < dimLengths[dim]; i++)
{
std::vector<T> index = {i};

updated_indexes.push_back(index);
};
}
else
{
// go through all the current indexes
for(const auto& index : indexes)
for(T i = 0; i < dimLengths[dim]; i++)
{
auto index_new = index;
index_new.push_back(i);

updated_indexes.push_back(index_new);
};
};

// update to the indexes (output)
indexes = updated_indexes;

// further to construct the indexes from the updated status
get_all_indexes(dimLengths, dim + 1, indexes);
};
};

template <typename T>
static T get_offset_from_index(const std::vector<T>& strides, const std::vector<T>& index)
{
T offset = 0;

assert(strides.size() == index.size());

for(int i = 0; i < index.size(); i++)
offset += strides[i] * static_cast<T>(index[i]);

return (offset);
};

template <typename T>
static inline T get_flatten_offset(const std::vector<T>& lengths, const std::vector<T>& index)
{
T offset = 0;

assert(lengths.size() == index.size() && lengths.size() > 0);

int len = lengths.size();
T stride = 1;

// for len==1, the loop is not executed
for(int i = len - 1; i > 0; i--)
{
offset += stride * static_cast<T>(index[i]);

stride *= lengths[i];
};

offset += stride * static_cast<T>(index[0]);

return (offset);
};
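
`get_flatten_offset` above computes a row-major linear index, accumulating strides from the innermost dimension outwards; the deleted loop is equivalent to Horner's rule over the index digits. A compact check of the same computation:

```cpp
#include <cassert>
#include <vector>

int flatten(const std::vector<int>& lengths, const std::vector<int>& index)
{
    int offset = 0;
    for(size_t i = 0; i < lengths.size(); ++i)
        offset = offset * lengths[i] + index[i]; // Horner form of the stride sum
    return offset;
}

int main()
{
    // index {1, 2, 3} in a {4, 5, 6} tensor: (1*5 + 2)*6 + 3 = 45
    assert(flatten({4, 5, 6}, {1, 2, 3}) == 45);
    return 0;
}
```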
template <typename InDataType,
typename AccDataType,
typename OutDataType,
ck::ReduceTensorOp_t ReduceOpId,
bool PropagateNan,
bool NeedIndices>
class ReductionHost
{
public:
ReductionHost() = default;
ReductionHost(HostTensorDescriptor& inDesc,
HostTensorDescriptor& outDesc,
const std::vector<int>& invariantDims_,
const std::vector<int>& toReduceDims_)
{
this->inLengths = to_int_vector(inDesc.GetLengths());
this->outLengths = to_int_vector(outDesc.GetLengths());
this->inStrides = to_int_vector(inDesc.GetStrides());
this->outStrides = to_int_vector(outDesc.GetStrides());

this->invariantDims = invariantDims_;
this->toReduceDims = toReduceDims_;

assert(this->inLengths.size() == this->outLengths.size());
assert(!this->toReduceDims.empty());

for(const auto dim : this->invariantDims)
this->invariantLengths.push_back(this->inLengths[dim]);

for(const auto dim : this->toReduceDims)
toReduceLengths.push_back(this->inLengths[dim]);

this->reduceAllDims = this->invariantDims.empty();
};

~ReductionHost(){};

void
Run(float alpha, const InDataType* in_data, float beta, OutDataType* out_data, int* indices)
{
if constexpr(NeedIndices)
RunImpl_with_indices(alpha, in_data, beta, out_data, indices);
else
RunImpl_no_indices(alpha, in_data, beta, out_data);
};

private:
std::vector<int> inLengths;
std::vector<int> outLengths;
std::vector<int> inStrides;
std::vector<int> outStrides;

std::vector<int> invariantLengths;
std::vector<int> toReduceLengths;

std::vector<int> invariantDims;
std::vector<int> toReduceDims;

bool reduceAllDims;

void RunImpl_with_indices(
float alpha, const InDataType* in_data, float beta, OutDataType* out_data, int* indices)
{
using ck::host_reduce::binop_with_nan_check;
using ck::host_reduce::binop_with_nan_check2;
using ck::host_reduce::float_equal_one;
using ck::host_reduce::float_equal_zero;
using ck::host_reduce::PosUnaryOpFn;
using ck::host_reduce::PreUnaryOpFn;
using ck::host_reduce::ReduceOpFn2;
using ck::host_reduce::ReduceOpZeroVal;

auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>();

int divider = 1;
for(int i = 0; i < toReduceLengths.size(); i++)
divider *= toReduceLengths[i];

auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);

if(reduceAllDims)
{
std::vector<std::vector<int>> indexes_1;

get_all_indexes(inLengths, 0, indexes_1); // generate the input indexes space

auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
int accuIndex = 0;

// go through indexes of the invariant dimensions
for(const auto& src_index : indexes_1)
{
auto src_offset = get_offset_from_index(this->inStrides, src_index);

auto currVal = static_cast<AccDataType>(in_data[src_offset]);

// unary operation before reducing, needed by AMAX. For MIN/MAX, nothing is actually
// done
PreUnaryOp(currVal);

auto currIndex = get_flatten_offset(inLengths, src_index);
binop_with_nan_check2<AccDataType, PropagateNan>(
opReduce, accuVal, currVal, accuIndex, currIndex);
};

// scale the accumulated value
if(!float_equal_one(alpha))
accuVal *= static_cast<AccDataType>(alpha);

// scale the prior dst value and add it to the accumulated value
if(!float_equal_zero(beta))
accuVal += static_cast<AccDataType>(out_data[0]) * static_cast<AccDataType>(beta);

// store the reduced value to dst location
out_data[0] = static_cast<OutDataType>(accuVal);
indices[0] = accuIndex;
}
else
{
std::vector<std::vector<int>> indexes_1, indexes_2;

get_all_indexes(
this->invariantLengths, 0, indexes_1); // generate the invariant indexes space
get_all_indexes(
this->toReduceLengths, 0, indexes_2); // generate the toReduce indexes space

// go through indexes of the invariant dimensions
for(const auto& index_1 : indexes_1)
{
std::vector<int> src_index;
std::vector<int> dst_index;

src_index.resize(this->inLengths.size());

// generate the part of src index belonging to invariant dims
for(int k = 0; k < invariantDims.size(); k++)
src_index[invariantDims[k]] = index_1[k];

for(int k = 0; k < invariantDims.size(); k++)
dst_index.push_back(index_1[k]);

int dst_offset = get_offset_from_index(this->outStrides, dst_index);

AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
int accuIndex = 0;

// go through indexes of the toReduce dimensions
for(const auto& index_2 : indexes_2)
{
// generate the part of src index belonging to toReduce dims
for(int k = 0; k < toReduceDims.size(); k++)
src_index[toReduceDims[k]] = index_2[k];

auto src_offset = get_offset_from_index(this->inStrides, src_index);

auto currVal = static_cast<AccDataType>(in_data[src_offset]);
// unary operation before reducing, needed by AMAX. For MIN/MAX, nothing is
// actually done
PreUnaryOp(currVal);

auto currIndex = get_flatten_offset(toReduceLengths, index_2);
binop_with_nan_check2<AccDataType, PropagateNan>(
opReduce, accuVal, currVal, accuIndex, currIndex);
};

// scale the accumulated value
if(!float_equal_one(alpha))
accuVal *= static_cast<AccDataType>(alpha);

// scale the prior dst value and add it to the accumulated value
if(!float_equal_zero(beta))
accuVal += static_cast<AccDataType>(out_data[dst_offset]) *
static_cast<AccDataType>(beta);

// store the reduced value to dst location
out_data[dst_offset] = static_cast<OutDataType>(accuVal);
indices[dst_offset] = accuIndex;
};
};
}; // end of RunImpl_with_indices()

void
RunImpl_no_indices(float alpha, const InDataType* in_data, float beta, OutDataType* out_data)
{
using ck::host_reduce::binop_with_nan_check;
using ck::host_reduce::binop_with_nan_check2;
using ck::host_reduce::float_equal_one;
using ck::host_reduce::float_equal_zero;
using ck::host_reduce::PosUnaryOpFn;
using ck::host_reduce::PreUnaryOpFn;
using ck::host_reduce::ReduceOpFn;
using ck::host_reduce::ReduceOpZeroVal;

auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();

int divider = 1;
for(int i = 0; i < toReduceLengths.size(); i++)
divider *= toReduceLengths[i];

auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);

if(reduceAllDims)
{
std::vector<std::vector<int>> indexes_1;

get_all_indexes(inLengths, 0, indexes_1); // generate the input indexes space

auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();

// go through indexes of the invariant dimensions
for(const auto& src_index : indexes_1)
{
auto src_offset = get_offset_from_index(this->inStrides, src_index);

auto currVal = static_cast<AccDataType>(in_data[src_offset]);

PreUnaryOp(currVal);

binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
};

PosUnaryOp(accuVal);

// scale the accumulated value
if(!float_equal_one(alpha))
accuVal *= static_cast<AccDataType>(alpha);

// scale the prior dst value and add it to the accumulated value
if(!float_equal_zero(beta))
accuVal += static_cast<AccDataType>(out_data[0]) * static_cast<AccDataType>(beta);

// store the reduced value to dst location
out_data[0] = static_cast<OutDataType>(accuVal);
}
else
{
std::vector<std::vector<int>> indexes_1, indexes_2;

get_all_indexes(
this->invariantLengths, 0, indexes_1); // generate the invariant indexes space
get_all_indexes(
this->toReduceLengths, 0, indexes_2); // generate the toReduce indexes space

// go through indexes of the invariant dimensions
for(const auto& index_1 : indexes_1)
{
std::vector<int> src_index;
std::vector<int> dst_index;

src_index.resize(this->inLengths.size());

for(int k = 0; k < invariantDims.size(); k++)
dst_index.push_back(index_1[k]);

int dst_offset = get_offset_from_index(this->outStrides, dst_index);

// generate the part of src index belonging to invariant dims
for(int k = 0; k < invariantDims.size(); k++)
src_index[invariantDims[k]] = index_1[k];

AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();

// go through indexes of the toReduce dimensions
for(const auto& index_2 : indexes_2)
{
// generate the part of src index belonging to toReduce dims
for(int k = 0; k < toReduceDims.size(); k++)
src_index[toReduceDims[k]] = index_2[k];

auto src_offset = get_offset_from_index(this->inStrides, src_index);

auto currVal = static_cast<AccDataType>(in_data[src_offset]);

PreUnaryOp(currVal);

binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
};

PosUnaryOp(accuVal);

// scale the accumulated value
if(!float_equal_one(alpha))
accuVal *= static_cast<AccDataType>(alpha);

// scale the prior dst value and add it to the accumulated value
if(!float_equal_zero(beta))
accuVal += static_cast<AccDataType>(out_data[dst_offset]) *
static_cast<AccDataType>(beta);

// store the reduced value to dst location
out_data[dst_offset] = static_cast<OutDataType>(accuVal);
};
};
}; // end of RunImpl_no_indices()
};

}; // end of namespace host_reduce

}; // end of namespace ck

#endif
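
Although this header is deleted (its role is taken over by `host_reduction.hpp`, per the commit message), the contract both `RunImpl_*` paths implement is worth noting: `out = alpha * reduce(in) + beta * out`, with each scale applied only when it differs from 1 and 0 respectively. A self-contained miniature of the same contract for a 1-d ADD reduction:

```cpp
#include <cassert>
#include <vector>

// Miniature of the host reference contract: out = alpha * reduce(in) + beta * out.
float reduce_add(const std::vector<float>& in, float alpha, float beta, float prior_out)
{
    float acc = 0.0f; // identity value for ADD (cf. ReduceOpZeroVal below)
    for(float v : in)
        acc += v;

    if(alpha != 1.0f)
        acc *= alpha;
    if(beta != 0.0f)
        acc += prior_out * beta;

    return acc;
}

int main()
{
    assert(reduce_add({1, 2, 3, 4}, 0.5f, 2.0f, 10.0f) == 25.0f); // 0.5*10 + 2*10
    return 0;
}
```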
@@ -66,22 +66,22 @@ static inline bool float_equal_zero(half_float::half x)
return x == static_cast<half_float::half>(0.0f);
};

template <typename compType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(compType&)> PreUnaryOpFn(int)
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PreUnaryOpFn(int)
{
using std::abs;

if constexpr(ReduceOpId == ReduceTensorOp_t::NORM1)
{
return ([&](compType& a_) { a_ = abs(a_); });
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp_t::NORM2)
{
return ([&](compType& a_) { a_ = a_ * a_; });
return ([&](AccDataType& a_) { a_ = a_ * a_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp_t::AMAX)
{
return ([&](compType& a_) { a_ = abs(a_); });
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else
{
@@ -90,23 +90,23 @@ __host__ static inline std::function<void(compType&)> PreUnaryOpFn(int)
// ReduceTensorOp_t::MUL:
// ReduceTensorOp_t::MIN:
// ReduceTensorOp_t::MAX:
return ([&](compType&) {});
return ([&](AccDataType&) {});
};
};

template <typename compType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(compType&)> PosUnaryOpFn(int divider)
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t divider)
{
using std::sqrt;

if constexpr(ReduceOpId == ReduceTensorOp_t::NORM2)
{
return ([&](compType& a_) { a_ = sqrt(a_); });
return ([&](AccDataType& a_) { a_ = sqrt(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp_t::AVG)
{
return ([&, divider](compType& a_) {
a_ = a_ / static_cast<compType>(static_cast<float>(divider));
return ([&, divider](AccDataType& a_) {
a_ = a_ / static_cast<AccDataType>(static_cast<float>(divider));
});
}
else
@@ -117,44 +117,44 @@ __host__ static inline std::function<void(compType&)> PosUnaryOpFn(int divider)
// ReduceTensorOp_t::MIN:
// ReduceTensorOp_t::MAX:
// ReduceTensorOp_t::AMAX:
return ([&](compType&) {});
return ([&](AccDataType&) {});
}
};

template <typename compType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(compType&, compType)> ReduceOpFn()
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn()
{
if constexpr(ReduceOpId == ReduceTensorOp_t::ADD || ReduceOpId == ReduceTensorOp_t::AVG ||
ReduceOpId == ReduceTensorOp_t::NORM1 || ReduceOpId == ReduceTensorOp_t::NORM2)
{
return ([&](compType& a_, compType b_) { a_ = a_ + b_; });
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp_t::MUL)
{
return ([&](compType& a_, compType b_) { a_ = a_ * b_; });
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp_t::MIN)
{
return ([&](compType& a_, compType b_) {
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ > b_)
a_ = b_;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp_t::MAX || ReduceOpId == ReduceTensorOp_t::AMAX)
{
return ([&](compType& a_, compType b_) {
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ < b_)
a_ = b_;
});
}
};

template <typename compType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(compType&, compType, bool& changed)> ReduceOpFn2()
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType, bool& changed)> ReduceOpFn2()
{
if constexpr(ReduceOpId == ReduceTensorOp_t::MIN)
{
return ([&](compType& a_, compType b_, bool& changed) {
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ > b_)
{
a_ = b_;
@@ -166,7 +166,7 @@ __host__ static inline std::function<void(compType&, compType, bool& changed)> R
}
else if constexpr(ReduceOpId == ReduceTensorOp_t::MAX || ReduceOpId == ReduceTensorOp_t::AMAX)
{
return ([&](compType& a_, compType b_, bool& changed) {
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ < b_)
{
a_ = b_;
@@ -183,28 +183,28 @@ __host__ static inline std::function<void(compType&, compType, bool& changed)> R
// ReduceTensorOp_t::AVG:
// ReduceTensorOp_t::NORM1:
// ReduceTensorOp_t::NORM2:
return (std::function<void(compType&, compType, bool&)>{});
return (std::function<void(AccDataType&, AccDataType, bool&)>{});
};
};

template <typename compType, ReduceTensorOp_t ReduceOpId>
__host__ static inline compType ReduceOpZeroVal()
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
__host__ static inline AccDataType ReduceOpZeroVal()
{
if constexpr(ReduceOpId == ReduceTensorOp_t::MUL)
{
return (static_cast<compType>(1.0f));
return (static_cast<AccDataType>(1.0f));
}
else if constexpr(ReduceOpId == ReduceTensorOp_t::MIN)
{
return (std::numeric_limits<compType>::max());
return (std::numeric_limits<AccDataType>::max());
}
else if constexpr(ReduceOpId == ReduceTensorOp_t::MAX)
{
return (std::numeric_limits<compType>::lowest());
return (std::numeric_limits<AccDataType>::lowest());
}
else if constexpr(ReduceOpId == ReduceTensorOp_t::AMAX)
{
return (static_cast<compType>(0.0f));
return (static_cast<AccDataType>(0.0f));
}
else
{
@@ -212,14 +212,15 @@ __host__ static inline compType ReduceOpZeroVal()
// ReduceTensorOp_t::AVG
// ReduceTensorOp_t::NORM1
// ReduceTensorOp_t::NORM2
return (static_cast<compType>(0.0f));
return (static_cast<AccDataType>(0.0f));
};
};
template <typename compType, bool PropagateNan>
|
||||
__host__ static inline void binop_with_nan_check(std::function<void(compType&, compType)> opReduce,
|
||||
compType& accuVal,
|
||||
compType currVal)
|
||||
template <typename AccDataType, bool PropagateNan>
|
||||
__host__ static inline void
|
||||
binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
|
||||
AccDataType& accuVal,
|
||||
AccDataType currVal)
|
||||
{
|
||||
using std::isnan;
|
||||
|
||||
@@ -236,11 +237,11 @@ __host__ static inline void binop_with_nan_check(std::function<void(compType&, c
|
||||
};
|
||||
};
|
||||
|
||||
template <typename compType, bool PropagateNan>
|
||||
template <typename AccDataType, bool PropagateNan>
|
||||
__host__ static inline void
|
||||
binop_with_nan_check2(std::function<void(compType&, compType, bool&)> opReduce,
|
||||
compType& accuVal,
|
||||
compType currVal,
|
||||
binop_with_nan_check2(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
|
||||
AccDataType& accuVal,
|
||||
AccDataType currVal,
|
||||
int& accuIndex,
|
||||
int currIndex)
|
||||
{
|
||||
|
||||
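The hunk above is a mechanical rename of the template parameter `compType` to `AccDataType`; the functor bodies themselves are unchanged. As a reference for how the three functor factories compose, below is a minimal host-side sketch (added commentary, not part of the commit) of a NORM2 reduction: square each element, accumulate with ADD semantics, then take the square root of the accumulator. It assumes these helpers are reachable as `ck::host_reduce::*`, as the using-declarations in `host_reduction.hpp` later in this diff suggest.

```cpp
#include <vector>
#include "host_reduce_util.hpp"

float host_norm2_sketch(const std::vector<float>& x)
{
    using namespace ck::host_reduce; // assumed namespace, see note above

    constexpr auto OpId = ck::ReduceTensorOp_t::NORM2;

    auto preUnaryOp = PreUnaryOpFn<float, OpId>(static_cast<int>(x.size())); // a_ = a_ * a_
    auto opReduce   = ReduceOpFn<float, OpId>();                             // a_ = a_ + b_
    auto posUnaryOp = PosUnaryOpFn<float, OpId>(static_cast<int>(x.size())); // a_ = sqrt(a_)

    float acc = ReduceOpZeroVal<float, OpId>(); // identity element, 0.0f for NORM2

    for(float v : x)
    {
        preUnaryOp(v); // the pre-op is applied once per element
        opReduce(acc, v);
    }

    posUnaryOp(acc); // the post-op is applied once to the accumulator
    return acc;
}
```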
402 library/include/ck/library/host_tensor/host_reduction.hpp Normal file
@@ -0,0 +1,402 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef HOST_REDUCTION_HPP_
#define HOST_REDUCTION_HPP_

#include <vector>
#include <array>
#include <functional>

#include "reduction_enums.hpp"
#include "host_reduce_util.hpp"
#include "host_tensor.hpp"
#include "data_type.hpp"

template <int NDim>
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
                            std::vector<std::array<size_t, NDim>>& indexes)
{
    static_assert(NDim >= 1, "NDim >= 1 is required to use this function!");

    if constexpr(NDim == 1)
    {
        for(size_t i = 0; i < dimLengths[0]; i++)
        {
            std::array<size_t, 1> index{i};

            indexes.push_back(index);
        };
    }
    else
    {
        std::array<size_t, NDim - 1> partial_dim_lengths;

        for(int i = 0; i < NDim - 1; i++)
            partial_dim_lengths[i] = dimLengths[i + 1];

        std::vector<std::array<size_t, NDim - 1>> partial_indexes;

        get_all_indexes<NDim - 1>(partial_dim_lengths, partial_indexes);

        for(size_t i = 0; i < dimLengths[0]; i++)
            for(const auto& index : partial_indexes)
            {
                std::array<size_t, NDim> extIndex;

                extIndex[0] = i;

                for(int k = 0; k < NDim - 1; k++)
                    extIndex[k + 1] = index[k];

                indexes.push_back(extIndex);
            };
    };
};
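// Added commentary (not part of the committed 402-line file): get_all_indexes
// enumerates the whole index space recursively in row-major order; e.g. for
// dimLengths = {2, 3} it pushes {0,0} {0,1} {0,2} {1,0} {1,1} {1,2}.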

template <int NDim>
static size_t get_offset_from_index(const std::array<size_t, NDim>& strides,
                                    const std::array<size_t, NDim>& index)
{
    size_t offset = 0;

    for(int i = 0; i < NDim; i++)
        offset += strides[i] * index[i];

    return (offset);
};

template <int NDim>
static size_t get_offset_from_index(const std::vector<size_t>& strides,
                                    const std::array<size_t, NDim>& index)
{
    size_t offset = 0;

    for(int i = 0; i < NDim; i++)
        offset += strides[i] * index[i];

    return (offset);
};
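// Added commentary (not part of the committed file): the returned offset is
// the dot product of strides and index, so for strides = {12, 4, 1} and
// index = {1, 2, 3} the linearized offset is 1*12 + 2*4 + 3*1 = 23.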

template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          ck::ReduceTensorOp_t ReduceOpId,
          int Rank,
          int NumReduceDim,
          bool PropagateNan,
          bool NeedIndices>
struct ReductionHost
{
    using IndexDataType = int32_t;

    static constexpr int NumInvariantDim = Rank - NumReduceDim;

    std::vector<size_t> outStrides;
    std::vector<int> invariantDims;
    std::vector<int> reduceDims;

    IndexDataType divider;
    std::function<void(AccDataType&)> preUnaryOp;
    std::function<void(AccDataType&)> posUnaryOp;
    std::array<size_t, NumReduceDim> reduceLengths;
    std::array<size_t, NumReduceDim> reduceStrides;
    std::array<size_t, NumInvariantDim> invariantLengths;
    std::array<size_t, NumInvariantDim> invariantStrides;

    std::vector<std::array<size_t, NumReduceDim>> reduce_dim_indexes;
    std::vector<std::array<size_t, NumInvariantDim>> invariant_dim_indexes;

    ReductionHost(HostTensorDescriptor& inDesc,
                  HostTensorDescriptor& outDesc,
                  const std::vector<int>& invariantDims_,
                  const std::vector<int>& reduceDims_)
    {
        using ck::host_reduce::PosUnaryOpFn;
        using ck::host_reduce::PreUnaryOpFn;

        // this->outLengths = to_int_vector(outDesc.GetLengths());
        this->outStrides = outDesc.GetStrides();

        this->invariantDims = invariantDims_;
        this->reduceDims    = reduceDims_;

        int product = 1;

        for(int i = 0; i < NumReduceDim; i++)
        {
            reduceLengths[i] = inDesc.GetLengths()[reduceDims[i]];
            reduceStrides[i] = inDesc.GetStrides()[reduceDims[i]];
            product *= inDesc.GetLengths()[reduceDims[i]];
        };

        divider = product;

        for(int i = 0; i < NumInvariantDim; i++)
        {
            invariantLengths[i] = inDesc.GetLengths()[invariantDims[i]];
            invariantStrides[i] = inDesc.GetStrides()[invariantDims[i]];
        };

        reduce_dim_indexes.clear();
        get_all_indexes<NumReduceDim>(reduceLengths, reduce_dim_indexes);

        if constexpr(NumInvariantDim > 0)
        {
            invariant_dim_indexes.clear();
            get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
        };

        preUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
        posUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
    };

    void Run(float alpha,
             const InDataType* in_data,
             float beta,
             OutDataType* out_data,
             IndexDataType* out_indices)
    {
        if constexpr(NeedIndices)
        {
            RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
        }
        else
        {
            RunImpl_no_index(alpha, in_data, beta, out_data);
        };
    };
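    // Added commentary (not part of the committed file): both RunImpl_*
    // variants below compute out = alpha * reduce(in) + beta * out, with the
    // alpha/beta scaling carried out in AccDataType precision via
    // ck::type_convert.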

    void RunImpl_with_index(float alpha,
                            const InDataType* in_data,
                            float beta,
                            OutDataType* out_data,
                            IndexDataType* out_indices)
    {
        using ck::type_convert;
        using ck::host_reduce::binop_with_nan_check2;
        using ck::host_reduce::float_equal_one;
        using ck::host_reduce::float_equal_zero;
        using ck::host_reduce::ReduceOpFn2;
        using ck::host_reduce::ReduceOpZeroVal;

        auto opReduce2 = ReduceOpFn2<AccDataType, ReduceOpId>();

        if constexpr(NumInvariantDim == 0)
        {
            AccDataType accuVal     = ReduceOpZeroVal<AccDataType, ReduceOpId>();
            IndexDataType accuIndex = 0;

            for(IndexDataType i = 0; i < reduce_dim_indexes.size(); i++)
            {
                auto offset_reduce =
                    get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);

                auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

                preUnaryOp(currVal);

                auto currIndex = i;

                binop_with_nan_check2<AccDataType, PropagateNan>(
                    opReduce2, accuVal, currVal, accuIndex, currIndex);
            };

            posUnaryOp(accuVal);

            if(!float_equal_one(alpha))
                accuVal *= type_convert<AccDataType>(alpha);

            if(!float_equal_zero(beta))
                accuVal += type_convert<AccDataType>(out_data[0]) * type_convert<AccDataType>(beta);

            out_data[0]    = type_convert<OutDataType>(accuVal);
            out_indices[0] = accuIndex;
        }
        else
        {
            auto thread_reduce_func = [&](auto invariant_index) {
                AccDataType accuVal     = ReduceOpZeroVal<AccDataType, ReduceOpId>();
                IndexDataType accuIndex = 0;

                auto offset_invariant =
                    get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);

                for(IndexDataType i = 0; i < reduce_dim_indexes.size(); i++)
                {
                    auto offset_reduce =
                        get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);

                    auto currVal =
                        type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

                    preUnaryOp(currVal);

                    auto currIndex = i;

                    binop_with_nan_check2<AccDataType, PropagateNan>(
                        opReduce2, accuVal, currVal, accuIndex, currIndex);
                };

                posUnaryOp(accuVal);

                if(!float_equal_one(alpha))
                    accuVal *= type_convert<AccDataType>(alpha);

                auto dst_offset =
                    get_offset_from_index<NumInvariantDim>(outStrides, invariant_index);

                if(!float_equal_zero(beta))
                    accuVal += type_convert<AccDataType>(out_data[dst_offset]) *
                               type_convert<AccDataType>(beta);

                out_data[dst_offset]    = type_convert<OutDataType>(accuVal);
                out_indices[dst_offset] = accuIndex;
            };

            std::size_t num_thread = std::thread::hardware_concurrency();
            std::size_t work_per_thread =
                (invariant_dim_indexes.size() + num_thread - 1) / num_thread;

            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                std::size_t iw_begin = it * work_per_thread;
                std::size_t iw_end =
                    std::min((it + 1) * work_per_thread, invariant_dim_indexes.size());

                auto f = [=] {
                    for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
                    {
                        thread_reduce_func(invariant_dim_indexes[iw]);
                    }
                };

                threads[it] = joinable_thread(f);
            }
        };
    };
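    // Added commentary (not part of the committed file): the invariant
    // indexes are split across std::thread::hardware_concurrency() workers
    // with a ceiling division, work_per_thread = ceil(N / num_thread);
    // joinable_thread is presumably the host_tensor wrapper that joins in its
    // destructor, so all workers are joined when the threads vector goes out
    // of scope (an assumption about that helper, not verified here).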

    void RunImpl_no_index(float alpha, const InDataType* in_data, float beta, OutDataType* out_data)
    {
        using ck::type_convert;
        using ck::host_reduce::binop_with_nan_check;
        using ck::host_reduce::float_equal_one;
        using ck::host_reduce::float_equal_zero;
        using ck::host_reduce::ReduceOpFn;
        using ck::host_reduce::ReduceOpZeroVal;

        auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();

        if constexpr(NumInvariantDim == 0)
        {
            AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();

            for(const auto& reduce_index : reduce_dim_indexes)
            {
                auto offset_reduce =
                    get_offset_from_index<NumReduceDim>(reduceStrides, reduce_index);

                auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

                preUnaryOp(currVal);

                binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
            };

            posUnaryOp(accuVal);

            if(!float_equal_one(alpha))
                accuVal *= type_convert<AccDataType>(alpha);

            if(!float_equal_zero(beta))
                accuVal += type_convert<AccDataType>(out_data[0]) * type_convert<AccDataType>(beta);

            out_data[0] = type_convert<OutDataType>(accuVal);
        }
        else
        {
            auto thread_reduce_func = [&](auto invariant_index) {
                AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();

                auto offset_invariant =
                    get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);

                for(const auto& reduce_index : reduce_dim_indexes)
                {
                    auto offset_reduce =
                        get_offset_from_index<NumReduceDim>(reduceStrides, reduce_index);

                    auto currVal =
                        type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

                    preUnaryOp(currVal);

                    binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
                };

                posUnaryOp(accuVal);

                if(!float_equal_one(alpha))
                    accuVal *= type_convert<AccDataType>(alpha);

                auto dst_offset =
                    get_offset_from_index<NumInvariantDim>(outStrides, invariant_index);

                if(!float_equal_zero(beta))
                    accuVal += type_convert<AccDataType>(out_data[dst_offset]) *
                               type_convert<AccDataType>(beta);

                out_data[dst_offset] = type_convert<OutDataType>(accuVal);
            };

            std::size_t num_thread = std::thread::hardware_concurrency();
            std::size_t work_per_thread =
                (invariant_dim_indexes.size() + num_thread - 1) / num_thread;

            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                std::size_t iw_begin = it * work_per_thread;
                std::size_t iw_end =
                    std::min((it + 1) * work_per_thread, invariant_dim_indexes.size());

                auto f = [=] {
                    for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
                    {
                        thread_reduce_func(invariant_dim_indexes[iw]);
                    }
                };

                threads[it] = joinable_thread(f);
            }
        };
    };
};

#endif
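That completes the new multi-threaded host reference implementation. A minimal driver sketch follows (added commentary, not committed test code) for an ADD reduction over dimensions {0, 1, 2} of the 4-d [16, 64, 32, 960] tensor used in the example above; it assumes `HostTensorDescriptor` can be constructed from a lengths vector and that the output descriptor carries only the invariant dimension.

```cpp
#include <vector>
#include "host_tensor.hpp"
#include "host_reduction.hpp"

void host_reduce_add_sketch(const float* in, float* out)
{
    std::vector<std::size_t> inLengths{16, 64, 32, 960};
    std::vector<std::size_t> outLengths{960}; // invariant dimension only

    HostTensorDescriptor inDesc(inLengths);
    HostTensorDescriptor outDesc(outLengths);

    std::vector<int> invariantDims{3};
    std::vector<int> reduceDims{0, 1, 2};

    ReductionHost<float,                     // InDataType
                  float,                     // AccDataType
                  float,                     // OutDataType
                  ck::ReduceTensorOp_t::ADD, // ReduceOpId
                  4,                         // Rank
                  3,                         // NumReduceDim
                  false,                     // PropagateNan
                  false>                     // NeedIndices
        hostReduce(inDesc, outDesc, invariantDims, reduceDims);

    // alpha = 1, beta = 0: plain reduction; no index output is needed.
    hostReduce.Run(1.0f, in, 0.0f, out, nullptr);
}
```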
@@ -6,23 +6,36 @@
#include "device_reduce_instance_blockwise_f32_f32_f32.hpp"
#include "device_reduce_instance_blockwise_f32_f64_f32.hpp"
#include "device_reduce_instance_blockwise_f64_f64_f64.hpp"
#include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
#include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
#include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
#include "device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp"
#include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
#include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
#include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
#include "device_reduce_instance_threadwise_f32_f64_f32.hpp"
#include "device_reduce_instance_threadwise_f64_f64_f64.hpp"
#include "device_reduce_instance_threadwise_i8_i8_i8.hpp"
#include "device_reduce_instance_threadwise_i8_i32_i8.hpp"
#include "device_reduce_instance_threadwise_b16_f32_b16.hpp"

#endif

@@ -17,7 +17,6 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
ReductionConfiguration_2<0, 2, 2, 2, 1>,
ReductionConfiguration_2<0, 1, 1, 2, 1>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<1, 2, 2, 1, 2>,
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<1, 1, 1, 1, 3>
// clang-format on
@@ -0,0 +1,60 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);

ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
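The instance tables in this and the following headers are decoded by the comment row above each block. Inferring from the inline `// for ...` annotations in this diff (not from the authoritative `reduction_enums.hpp` definition), the numeric columns read as:

```cpp
// ReduceOpId: 0 = ADD, 2 = MIN, 3 = MAX, 4 = AMAX, 5 = AVG, 7 = NORM2
// NanPropaOpt: 0 = no NaN propagation
// IndicesOpt: 0 = value-only variant, 1 = index-returning variant
// Rank / NumReduceDim: tensor rank and number of reduced dimensions
```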
@@ -13,21 +13,27 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on

@@ -13,12 +13,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on

@@ -13,30 +13,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on

@@ -13,12 +13,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on

@@ -13,30 +13,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on

@@ -0,0 +1,31 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif

@@ -0,0 +1,47 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
@@ -15,9 +15,7 @@ using reduce_configuration_2_instances_blockwise_second_call = std::tuple<
|
||||
// clang-format off
|
||||
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
|
||||
ReductionConfiguration_2<1, 2, 1, 1, 2>,
|
||||
ReductionConfiguration_2<1, 2, 2, 1, 2>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 3>,
|
||||
ReductionConfiguration_2<1, 1, 2, 1, 3>
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 3>
|
||||
// clang-format on
|
||||
>;
|
||||
#else
|
||||
|
||||
@@ -13,21 +13,27 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_B16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_B16_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 2, 1);
|
||||
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -13,12 +13,15 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -13,30 +13,39 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -13,12 +13,15 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -13,30 +13,39 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_I32_I32_I8_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_I32_I32_I8_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,47 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_I8_I8_I8_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_I8_I8_I8_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
@@ -17,7 +17,6 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
ReductionConfiguration_2<0, 2, 2, 2, 1>,
ReductionConfiguration_2<0, 1, 1, 2, 1>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<1, 2, 2, 1, 2>,
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<1, 1, 1, 1, 3>
// clang-format on

@@ -0,0 +1,31 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_B16_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_B16_F32_F32_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_atomic_add.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
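
Note that every bhalf_t instance above pairs the 16-bit input with a float accumulator, and the atomic-add path also writes a float output. Here is a self-contained sketch of the usual bfloat16-to-float widening that makes a float accumulator cheap, assuming the common representation in which bfloat16 is the upper half of a 32-bit IEEE-754 float; this is illustrative, not the library's own conversion code.

```cpp
#include <cstdint>
#include <cstring>

// Illustrative bf16 <-> float conversion under the common convention that
// bfloat16 stores the upper 16 bits of a 32-bit IEEE-754 float.
using bhalf_sketch_t = std::uint16_t;

inline float bhalf_to_float(bhalf_sketch_t x)
{
    std::uint32_t bits = static_cast<std::uint32_t>(x) << 16; // low mantissa bits become zero
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

inline bhalf_sketch_t float_to_bhalf(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<bhalf_sketch_t>(bits >> 16); // truncation; real code may round-to-nearest
}
```
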
@@ -13,9 +13,11 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);
// clang-format on

@@ -13,9 +13,11 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
// clang-format on

@@ -13,9 +13,11 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
// clang-format on

@@ -0,0 +1,60 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_B16_F32_B16_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_B16_F32_B16_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);

ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
@@ -13,21 +13,27 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on

@@ -13,12 +13,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on

@@ -13,25 +13,32 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);

ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
// clang-format on

@@ -13,6 +13,7 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on

@@ -13,33 +13,42 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);

ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);

// Will be moved to use MultiBlockAtomicAdd
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
// clang-format on

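The comment above marks double ADD/AVG as a temporary resident of the partial-reduce path: together with the blockwise_second_call files earlier in this diff, it implies a two-kernel scheme in which one kernel writes per-block partial results into a workspace and a second kernel reduces those partials. Here is a host-side sketch of that pattern under stated assumptions (the helper name and structure are made up; the real interfaces live in the device_reduce_* headers):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative two-pass reduction: pass 1 mimics MultiBlockPartialReduce by
// producing one partial per "block"; pass 2 mimics the blockwise "second call"
// that folds the partials into the final result.
double two_pass_sum(const std::vector<double>& in, std::size_t block_size)
{
    // pass 1: each block reduces its slice to a single partial value
    std::vector<double> partials; // the workspace, cf. GetWorkspaceSizeInBytes
    for(std::size_t b = 0; b < in.size(); b += block_size)
    {
        double acc = 0.0;
        for(std::size_t i = b; i < std::min(b + block_size, in.size()); ++i)
            acc += in[i];
        partials.push_back(acc);
    }

    // pass 2: the "second call" reduces the per-block partials
    double result = 0.0;
    for(double p : partials)
        result += p;
    return result;
}
```
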
@@ -0,0 +1,31 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_I8_I32_I8_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_I8_I32_I8_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
@@ -0,0 +1,47 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_I8_I8_I8_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_I8_I8_I8_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
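
A pattern worth noting in the int8 files: ADD/AVG instances accumulate in int32_t (the i8_i32_i8 files), while MIN/MAX/AMAX keep an int8_t accumulator (this i8_i8_i8 file). The reason is arithmetic range: a sum escapes int8 after a handful of elements, whereas a min/max result is always one of the inputs. A small self-contained demonstration:

```cpp
#include <cstdint>
#include <iostream>

// Why ADD/AVG need a wider accumulator for int8 inputs while MIN/MAX/AMAX
// do not: summing four values of 100 already overflows int8 (range -128..127).
int main()
{
    std::int8_t in[4] = {100, 100, 100, 100};

    std::int8_t acc8   = 0;
    std::int32_t acc32 = 0;
    for(std::int8_t v : in)
    {
        acc8 = static_cast<std::int8_t>(acc8 + v); // wraps: 400 does not fit in int8
        acc32 += v;
    }

    std::cout << "int8 accumulator:  " << static_cast<int>(acc8) << "\n"; // wrapped value
    std::cout << "int32 accumulator: " << acc32 << "\n";                  // 400
}
```
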
@@ -17,7 +17,6 @@ using reduce_configuration_2_instances_threadwise = std::tuple<
ReductionConfiguration_2<0, 2, 2, 2, 1>,
ReductionConfiguration_2<0, 1, 1, 2, 1>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<1, 2, 2, 1, 2>,
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<1, 1, 1, 1, 3>
// clang-format on

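These reduce_configuration_2 tuples enumerate candidate tuning configurations as types; an instance factory typically walks such a tuple at compile time and stamps out one kernel per entry. Here is a minimal sketch of that pattern with a stand-in config type (all names below are hypothetical, not the library's):

```cpp
#include <cstddef>
#include <cstdio>
#include <tuple>
#include <utility>

// Stand-in for ReductionConfiguration_2<...>: a bundle of compile-time knobs.
template <int I0, int I1, int I2, int I3, int I4>
struct ConfigSketch
{
    static void describe() { std::printf("config <%d,%d,%d,%d,%d>\n", I0, I1, I2, I3, I4); }
};

using config_list = std::tuple<ConfigSketch<0, 2, 2, 2, 1>, ConfigSketch<1, 1, 1, 1, 3>>;

// Compile-time iteration over the tuple, one instantiation per configuration.
template <std::size_t... Is>
void instantiate_all(std::index_sequence<Is...>)
{
    (std::tuple_element_t<Is, config_list>::describe(), ...);
}

int main()
{
    instantiate_all(std::make_index_sequence<std::tuple_size_v<config_list>>{});
}
```
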
@@ -0,0 +1,60 @@
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_B16_F32_B16_HPP
#define DEVICE_REDUCE_INSTANCE_THREADWISE_B16_F32_B16_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);

ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
@@ -13,21 +13,27 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on

@@ -13,12 +13,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on

@@ -13,30 +13,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on

@@ -13,12 +13,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on

@@ -13,30 +13,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on

@@ -0,0 +1,31 @@
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I32_I8_HPP
#define DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I32_I8_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
@@ -0,0 +1,47 @@
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I8_I8_HPP
#define DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I8_I8_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
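
The MIN/MAX/AMAX tables come in pairs differing only in the IndicesOpt column: the 0 variants return just the reduced value, while the 1 variants also report the index of the winning element. A minimal host-side sketch of a MAX reduction in both flavors (illustrative only, not the library's device code):

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>

// IndicesOpt = 0 flavor: value only. Assumes n >= 1.
std::int8_t max_value(const std::int8_t* p, std::size_t n)
{
    std::int8_t best = p[0];
    for(std::size_t i = 1; i < n; ++i)
        if(p[i] > best)
            best = p[i];
    return best;
}

// IndicesOpt = 1 flavor: value plus the index it came from. Assumes n >= 1.
std::pair<std::int8_t, std::size_t> max_with_index(const std::int8_t* p, std::size_t n)
{
    std::size_t best_i = 0;
    for(std::size_t i = 1; i < n; ++i)
        if(p[i] > p[best_i])
            best_i = i;
    return {p[best_i], best_i};
}
```
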
@@ -5,24 +5,37 @@ set(DEVICE_REDUCE_INSTANCE_SOURCE
device_reduce_instance_blockwise_f32_f32_f32.cpp;
device_reduce_instance_blockwise_f32_f64_f32.cpp;
device_reduce_instance_blockwise_f64_f64_f64.cpp;
device_reduce_instance_blockwise_i8_i32_i8.cpp;
device_reduce_instance_blockwise_i8_i8_i8.cpp;
device_reduce_instance_blockwise_b16_f32_b16.cpp;
device_reduce_instance_threadwise_f16_f16_f16.cpp;
device_reduce_instance_threadwise_f16_f32_f16.cpp;
device_reduce_instance_threadwise_f32_f32_f32.cpp;
device_reduce_instance_threadwise_f32_f64_f32.cpp;
device_reduce_instance_threadwise_f64_f64_f64.cpp;
device_reduce_instance_threadwise_i8_i32_i8.cpp;
device_reduce_instance_threadwise_i8_i8_i8.cpp;
device_reduce_instance_threadwise_b16_f32_b16.cpp;
device_reduce_instance_blockwise_second_call_f16_f16_f16.cpp;
device_reduce_instance_blockwise_second_call_f32_f32_f16.cpp;
device_reduce_instance_blockwise_second_call_f32_f32_f32.cpp;
device_reduce_instance_blockwise_second_call_f64_f64_f32.cpp;
device_reduce_instance_blockwise_second_call_f64_f64_f64.cpp;
device_reduce_instance_blockwise_second_call_i32_i32_i8.cpp;
device_reduce_instance_blockwise_second_call_i8_i8_i8.cpp;
device_reduce_instance_blockwise_second_call_f32_f32_b16.cpp;
device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp;
device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp;
device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp;
device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp;
device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.cpp;
device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.cpp;
device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.cpp;
device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.cpp;
device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.cpp;
device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.cpp;
device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.cpp;
device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.cpp;
)

add_library(device_reduce_instance SHARED ${DEVICE_REDUCE_INSTANCE_SOURCE})

@@ -0,0 +1,53 @@
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);

ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -8,21 +8,27 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on

@@ -8,12 +8,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on

@@ -8,30 +8,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on

@@ -8,12 +8,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
@@ -0,0 +1,24 @@
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
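The numeric ReduceOpId column in these lists is opaque on its own. Read against the `// for ...` annotations, the mapping below falls out; it is a reconstruction from those comments for the reader's convenience, not a quote of the enum actually defined in reduction_enums.hpp.

```cpp
// Assumed mapping of the ReduceOpId column, inferred from the "// for ..."
// comments in the instance lists; the authoritative definition lives in
// reduction_enums.hpp.
enum class ReduceOpId
{
    ADD   = 0,
    MIN   = 2,
    MAX   = 3,
    AMAX  = 4,
    AVG   = 5,
    NORM2 = 7
};
```

Note also that the MIN/MAX/AMAX instances appear twice, once with IndicesOpt 0 (values only) and once with IndicesOpt 1 (values plus indices), while ADD/AVG/NORM2 only ever use IndicesOpt 0.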
@@ -0,0 +1,40 @@
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -8,21 +8,27 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
@@ -0,0 +1,53 @@
#include "device_reduce_instance_blockwise_second_call.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 2, 1);

ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 2, 1);
// clang-format on
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
@@ -0,0 +1,24 @@
#include "device_reduce_instance_blockwise_second_call.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,40 @@
#include "device_reduce_instance_blockwise_second_call.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,24 @@
#include "device_reduce_instance_multiblock_atomic_add.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
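Note the type pattern here: the bhalf_t MultiBlockAtomicAdd instances read bhalf_t but accumulate in float and emit float, unlike the blockwise bhalf_t instances that write bhalf_t back out. A plausible reason is that the atomic add must run on a natively supported type, and widening bf16 to float is exact. A minimal sketch of that widening, assuming bhalf_t is a raw ushort holding the high 16 bits of an IEEE-754 float:

```cpp
#include <cstdint>
#include <cstring>

using bhalf_t = unsigned short; // assumption: raw bf16 bits

// Widen a bf16 value to float; bf16 is the upper 16 bits of a binary32,
// so shifting into the high half reproduces the exact float value.
inline float bhalf_to_float(bhalf_t x)
{
    const std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}
```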
@@ -8,9 +8,11 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);
// clang-format on
@@ -8,9 +8,11 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
// clang-format on
@@ -8,9 +8,11 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
// clang-format on
@@ -0,0 +1,53 @@
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);

ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
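The second half of this file lists the IndicesOpt = 1 variants of MIN/MAX/AMAX, which return the flattened index of the winning element alongside the reduced value. A hypothetical host-side illustration of the AMAX-with-index semantics (reference code only, not the CK kernel):

```cpp
#include <cmath>
#include <cstddef>

// Reduce to the largest absolute value and the index where it occurs.
void amax_with_index(const float* in, std::size_t n, float& out_val, std::size_t& out_idx)
{
    out_val = std::fabs(in[0]);
    out_idx = 0;
    for(std::size_t i = 1; i < n; ++i)
    {
        const float v = std::fabs(in[i]);
        if(v > out_val) // strict '>' keeps the first occurrence on ties
        {
            out_val = v;
            out_idx = i;
        }
    }
}
```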
@@ -8,21 +8,27 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
@@ -8,25 +8,32 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);

ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
// clang-format on
@@ -8,6 +8,7 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
@@ -8,33 +8,42 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);

ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);

// Will be moved to use MultiBlockAtomicAdd
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
// clang-format on
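The comment above flags that the double ADD/AVG instances still go through MultiBlockPartialReduce, i.e. a two-kernel flow (partial reduce into a workspace, then a blockwise second call over that workspace) rather than a single atomic-add kernel. A rough host-side sketch of that two-stage shape, as illustration only and not the CK API:

```cpp
#include <cstddef>
#include <vector>

// Stage 1 writes one partial sum per workgroup into a workspace; stage 2
// (the "second call") reduces the workspace to the final result.
double reduce_add_two_stage(const std::vector<double>& in, std::size_t blocks)
{
    const std::size_t chunk = (in.size() + blocks - 1) / blocks;
    std::vector<double> workspace(blocks, 0.0);
    for(std::size_t b = 0; b < blocks; ++b)
        for(std::size_t i = b * chunk; i < in.size() && i < (b + 1) * chunk; ++i)
            workspace[b] += in[i]; // stage 1: blockwise partial reduce
    double result = 0.0;
    for(const double p : workspace)
        result += p; // stage 2: final reduce over the partials
    return result;
}
```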
@@ -0,0 +1,24 @@
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,40 @@
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,53 @@
#include "device_reduce_instance_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);

ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -8,21 +8,27 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
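Editor's note: the IndicesOpt=1 blocks only ever instantiate MIN/MAX/AMAX, because those are the operations whose result has a location; ADD/AVG/NORM2 have none. A hypothetical host-side sketch of what such an instance computes (a flat index here; the device kernels track indices within the reduced dimensions, so this is illustrative only):

```cpp
#include <cstddef>
#include <utility>

// Return the maximum value and the flat index where it occurs.
// Assumes n > 0.
template <typename T>
std::pair<T, std::size_t> reduce_max_with_index(const T* data, std::size_t n)
{
    T value           = data[0];
    std::size_t index = 0;
    for(std::size_t i = 1; i < n; ++i)
    {
        if(data[i] > value)
        {
            value = data[i]; // new winner
            index = i;       // remember where it came from
        }
    }
    return {value, index};
}
```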
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
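Editor's note: every hunk instantiates the same four (Rank, NumReduceDim) shapes: (4, 3), (4, 4), (4, 1), and (2, 1). The output keeps the lengths of the Rank - NumReduceDim invariant dimensions; which dimensions are reduced is a runtime argument. A sketch of that bookkeeping (`invariant_lengths` is a hypothetical helper, not part of CK):

```cpp
#include <cstddef>
#include <vector>

// Output lengths after reducing `reduce_dims` out of `in_lengths`.
// E.g. a rank-4 input {8, 16, 32, 64} reduced over dims {1, 2, 3}
// leaves {8}; reducing over all four dims leaves a scalar (empty list).
std::vector<std::size_t>
invariant_lengths(const std::vector<std::size_t>& in_lengths,
                  const std::vector<std::size_t>& reduce_dims)
{
    std::vector<std::size_t> out;
    for(std::size_t d = 0; d < in_lengths.size(); ++d)
    {
        bool reduced = false;
        for(std::size_t r : reduce_dims)
            reduced = reduced || (r == d);
        if(!reduced)
            out.push_back(in_lengths[d]); // invariant dimension survives
    }
    return out;
}
```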
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
@@ -0,0 +1,25 @@
#include "device_reduce_instance_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
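Editor's note: this new file instantiates only ADD and AVG for int8_t inputs, with an int32_t AccDataType. A hypothetical host-side sketch of why the accumulator must be widened (`average_int8` is not a CK function):

```cpp
#include <cstddef>
#include <cstdint>

// Average of int8 values with an int32 accumulator. An int8 accumulator
// would overflow after a handful of elements; int32 gives headroom.
// Assumes n > 0. The clamp is one plausible output policy; the device
// kernels' actual conversion goes through CK's type_convert.
int8_t average_int8(const int8_t* data, std::size_t n)
{
    int32_t acc = 0;
    for(std::size_t i = 0; i < n; ++i)
        acc += static_cast<int32_t>(data[i]); // widen before accumulating
    int32_t avg = acc / static_cast<int32_t>(n);
    if(avg > INT8_MAX) { avg = INT8_MAX; }
    if(avg < INT8_MIN) { avg = INT8_MIN; }
    return static_cast<int8_t>(avg);
}
```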