mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Reduction for int8 and bfloat16 (#125)
* Use thread cluster descriptor and explicit M_K 2d descriptor to simply Blockwise Reduction * Change by replacing ReduceDims by NumReduceDims as Device Reduce interface template parameter * Rename the folder name for the pool2d and reduce examples * Update to reduction test scripts * Add Readme for pool2d_fwd and reduce_blockwise examples * Add support for int8_t reduction (ADD/AVG, MIN/MAX/AMAX) * Tiny fix in reduce profiler and tiny update in reduce testing scripts * Tiny fix in testing script profile_reduce_no_index.sh * Tiny fix in testing script profile_reduce_no_index.sh * Add support for bfp16 reduction (using bhalf_t = ushort) * Tiny fix in amd_buffer_addressing.hpp * Tiny change in script/profile_reduce_with_index.sh * Use AccDataType for Beta value and use element_wise::PassThrough * Use type_convert for type converting in host layer reduction * Renaming and refining in Reduction profiler/device layer/examples * Renaming and refining in Reduction profiler/device layer/examples * Renaming all NumReduceDims to NumReduceDim * Fix the leaked type_convert in ThreadwiseTensorSliceTransfer_v2 * Update to testing scripts to add bf16 support * added more static_assert * Remove buggy tunable configurations defined in device_reduce_instance_xxx.hpp * Add static_assert to give compile-time warning for incorrect thread slice-size/vector-size configurations * minor change * Refine and fix (in GetWorkspaceSizeInBytes of MultiBlockPartialReduce) to make int8 completely pass * Tiny renaming in gridwise_2d_reduction_multiblock_partial_reduce.hpp * Tiny fix in script/profile_reduce_no_index.sh * Refine in DeviceReduce layer with regard to using NumInvariantDim/NumReduceDim or InvariantDims/ReduceDims * Generic renaming in host reduction and DeviceReduce layer * Add support for 4-d all dimension reduction in the profiler and add_device_reduce_xxx instances * Use multi-thread and simplification for host Reduction implementation * Add ctest for reduction * Update to clarify the using of data init method in produce_reduce/example_reduce/test_reduce/ * Update to the reduce CTest executables to enable default testing behavior when no command argument * Renaming Co-authored-by: Jianfeng yan <jfyan008@gmail.com>
This commit is contained in:
@@ -5,24 +5,37 @@ set(DEVICE_REDUCE_INSTANCE_SOURCE
|
||||
device_reduce_instance_blockwise_f32_f32_f32.cpp;
|
||||
device_reduce_instance_blockwise_f32_f64_f32.cpp;
|
||||
device_reduce_instance_blockwise_f64_f64_f64.cpp;
|
||||
device_reduce_instance_blockwise_i8_i32_i8.cpp;
|
||||
device_reduce_instance_blockwise_i8_i8_i8.cpp;
|
||||
device_reduce_instance_blockwise_b16_f32_b16.cpp;
|
||||
device_reduce_instance_threadwise_f16_f16_f16.cpp;
|
||||
device_reduce_instance_threadwise_f16_f32_f16.cpp;
|
||||
device_reduce_instance_threadwise_f32_f32_f32.cpp;
|
||||
device_reduce_instance_threadwise_f32_f64_f32.cpp;
|
||||
device_reduce_instance_threadwise_f64_f64_f64.cpp;
|
||||
device_reduce_instance_threadwise_i8_i32_i8.cpp;
|
||||
device_reduce_instance_threadwise_i8_i8_i8.cpp;
|
||||
device_reduce_instance_threadwise_b16_f32_b16.cpp;
|
||||
device_reduce_instance_blockwise_second_call_f16_f16_f16.cpp;
|
||||
device_reduce_instance_blockwise_second_call_f32_f32_f16.cpp;
|
||||
device_reduce_instance_blockwise_second_call_f32_f32_f32.cpp;
|
||||
device_reduce_instance_blockwise_second_call_f64_f64_f32.cpp;
|
||||
device_reduce_instance_blockwise_second_call_f64_f64_f64.cpp;
|
||||
device_reduce_instance_blockwise_second_call_i32_i32_i8.cpp;
|
||||
device_reduce_instance_blockwise_second_call_i8_i8_i8.cpp;
|
||||
device_reduce_instance_blockwise_second_call_f32_f32_b16.cpp;
|
||||
device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp;
|
||||
device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp;
|
||||
device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp;
|
||||
device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp;
|
||||
device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.cpp;
|
||||
device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.cpp;
|
||||
device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.cpp;
|
||||
device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.cpp;
|
||||
device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.cpp;
|
||||
device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.cpp;
|
||||
device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.cpp;
|
||||
device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.cpp;
|
||||
)
|
||||
|
||||
add_library(device_reduce_instance SHARED ${DEVICE_REDUCE_INSTANCE_SOURCE})
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);
|
||||
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -8,21 +8,27 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -8,21 +8,27 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
#include "device_reduce_instance_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 2, 1);
|
||||
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
#include "device_reduce_instance_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
#include "device_reduce_instance_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
#include "device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -8,9 +8,11 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,9 +8,11 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,9 +8,11 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);
|
||||
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -8,21 +8,27 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,25 +8,32 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,6 +8,7 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,33 +8,42 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
|
||||
// Will be moved to use MultiBlockAtomicAdd
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,53 @@
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);
|
||||
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -8,21 +8,27 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,12 +8,15 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -8,30 +8,39 @@ namespace device_reduce_instance {
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
|
||||
ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user