mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Overhaul to Reducton and its dependants (#237)
* Tiny fix in dynamic_buffer.hpp to support vectorized AtomicAdd for double type
* Update to host layer and host reduction
* Merge and remove reduction kernels
* Merge and remove reduction device interfaces and update pooling device interface
* Merge and remove useless reduction device instances
* Update to reduction profiler and reduction ctests
* Update to reduction and pooling examples and add one reduction example
* Change to reduction examples to let them testable by ctest
* Add explicit pass checking for reduction and pooling examples
* Explicit assignment of tensor shapes in example reduce_blockwise_two_call
* Use atomic_add to repace atomicAdd and add atomic_add for double type
* Add reduce ctest support for double data type
* Replace to_int_vector() by using c++ std::vector::assign()
* Keep DeviceReduceThreadWise separated from DeviceReduceBlockWise
* Merge DeviceReduceBlockWise and DeviceReduceMultiBlockAtomicAdd into DeviceReduceMultiBlock
* Add GetAtomicOperationZeroValue() support for AtomicMax
* Tiny change to reduce example README.md
* Fix some tiny issues due to branch merging
* Revoke previous change in dynamic_buffer.hpp and add atomic_add for double2_t
* Add reduce multiblock_atomic_add instances for fp64 to verify vectorized atomic_add on fp64
* Renaming
* Clean the header includings in device_reduce instances header files
[ROCm/composable_kernel commit: 63eee2d999]
This commit is contained in:
@@ -325,7 +325,7 @@ struct DynamicBuffer
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
|
||||
atomic_add(c_style_pointer_cast<X*>(&p_data_[i]), x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,12 @@ __device__ float atomic_add<float>(float* p_dst, const float& x)
|
||||
return atomicAdd(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ double atomic_add<double>(double* p_dst, const double& x)
|
||||
{
|
||||
return atomicAdd(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
|
||||
{
|
||||
@@ -45,6 +51,23 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
|
||||
return vy.template AsType<float2_t>()[I0];
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
const vector_type<double, 2> vx{x};
|
||||
vector_type<double, 2> vy{0};
|
||||
|
||||
vy.template AsType<double>()(I0) =
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst), vx.template AsType<double>()[I0]);
|
||||
vy.template AsType<double>()(I1) =
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, vx.template AsType<double>()[I1]);
|
||||
|
||||
return vy.template AsType<double2_t>()[I0];
|
||||
}
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
||||
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
|
||||
|
||||
@@ -26,7 +26,8 @@
|
||||
#ifndef CK_REDUCTION_OPERATOR_HPP
|
||||
#define CK_REDUCTION_OPERATOR_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "config.hpp"
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -41,12 +42,10 @@ namespace reduce {
|
||||
// when operated against them, and the concept is similar to zero vector in
|
||||
// vector space
|
||||
// (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
|
||||
// 2) indexable -- boolean value indicating whether indices of the operated elements could be
|
||||
// recorded. Usually, Min/Max operator could
|
||||
// need to record the indices of elements. For operator like Add/Mul, no need to
|
||||
// record the indices.
|
||||
// 3) operator() -- the first argument of the operator must be both an input & output, and the
|
||||
// corresponding variable usually stores
|
||||
// 2) IsCompatibleInMemoryDataOperation() -- return true if the reduction task corresponding to this
|
||||
// operator can use the InMemoryDataOperation to finalize, or else it return false 3) operator() --
|
||||
// the first argument of the operator must be both an input & output, and the corresponding variable
|
||||
// usually stores
|
||||
// the accumulated result of many operator() calls; the second argument is only an
|
||||
// input. For indexable binary
|
||||
// operator, the second version of operator() has third argument (which is an
|
||||
@@ -62,6 +61,13 @@ struct Add
|
||||
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
|
||||
|
||||
__device__ static constexpr bool
|
||||
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
{
|
||||
return operation == InMemoryDataOperationEnum::AtomicAdd ||
|
||||
operation == InMemoryDataOperationEnum::Set;
|
||||
};
|
||||
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
|
||||
};
|
||||
|
||||
@@ -72,6 +78,12 @@ struct Mul
|
||||
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
|
||||
|
||||
__device__ static constexpr bool
|
||||
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
{
|
||||
return operation == InMemoryDataOperationEnum::Set;
|
||||
};
|
||||
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
|
||||
};
|
||||
|
||||
@@ -85,6 +97,13 @@ struct Max
|
||||
return NumericLimits<T>::Lowest();
|
||||
};
|
||||
|
||||
__device__ static constexpr bool
|
||||
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
{
|
||||
// ToChange: atomic_max to be added
|
||||
return operation == InMemoryDataOperationEnum::Set;
|
||||
};
|
||||
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const
|
||||
{
|
||||
if(a < b)
|
||||
@@ -111,6 +130,13 @@ struct Min
|
||||
return NumericLimits<T>::Max();
|
||||
};
|
||||
|
||||
__device__ static constexpr bool
|
||||
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
{
|
||||
// ToChange: atomic_min to be added
|
||||
return operation == InMemoryDataOperationEnum::Set;
|
||||
};
|
||||
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const
|
||||
{
|
||||
if(a > b)
|
||||
@@ -134,6 +160,13 @@ struct AMax
|
||||
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
|
||||
|
||||
__device__ static constexpr bool
|
||||
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
{
|
||||
// ToChange: atomic_max to be added
|
||||
return operation == InMemoryDataOperationEnum::Set;
|
||||
};
|
||||
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const
|
||||
{
|
||||
if(a < b)
|
||||
@@ -150,6 +183,17 @@ struct AMax
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
{
|
||||
T result = ck::type_convert<T>(0.0f);
|
||||
|
||||
if(operation == InMemoryDataOperationEnum::AtomicMax)
|
||||
result = ck::NumericLimits<T>::Lowest();
|
||||
|
||||
return (result);
|
||||
};
|
||||
|
||||
}; // end of namespace reduce
|
||||
|
||||
} // end of namespace ck
|
||||
|
||||
Reference in New Issue
Block a user