mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Update to the Reduction API and instances (#476)
* Simplify the macros for declaring and defining the add_device_reduce_instance_xxxx() instances * Change the types of lengths and strides from std::vector to std::array for the reduction device interfaces * Remove DeviceSoftmaxImpl's depending on DeviceReduceMultiblock * Split the cpp and hpp files for reduction instances to enable more parallel compiling * Remove the using of macros for declaring reduction instances and instance references * Update to add_device_reduce_instance_xxxx templated functions * Use ReduceOperation+InElementwiseOp+AccElementwiseOp to repace the ReduceOpId in defining add_reduce_instance_xxxx() templates * Change return format
This commit is contained in:
@@ -5,11 +5,10 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
template <ck::index_t Rank, ck::index_t NumReduceDim>
|
||||
std::vector<int> get_invariant_dims(const std::vector<int>& reduceDims)
|
||||
template <int Rank, int NumReduceDim>
|
||||
static inline std::array<int, Rank - NumReduceDim>
|
||||
get_invariant_dims(const std::array<int, NumReduceDim>& reduceDims)
|
||||
{
|
||||
assert(NumReduceDim == reduceDims.size());
|
||||
|
||||
int reduceFlag = 0;
|
||||
|
||||
// flag the bits for the reduceDims
|
||||
@@ -18,13 +17,15 @@ std::vector<int> get_invariant_dims(const std::vector<int>& reduceDims)
|
||||
reduceFlag |= 1 << reduceDims[i];
|
||||
};
|
||||
|
||||
std::vector<int> invariantDims;
|
||||
std::array<int, Rank - NumReduceDim> invariantDims;
|
||||
|
||||
// collect invariant dimensions
|
||||
int dim = 0;
|
||||
for(int i = 0; i < Rank; i++)
|
||||
if((reduceFlag & (1 << i)) == 0)
|
||||
{
|
||||
invariantDims.push_back(i);
|
||||
invariantDims[dim] = i;
|
||||
dim++;
|
||||
};
|
||||
|
||||
return invariantDims;
|
||||
|
||||
Reference in New Issue
Block a user