mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Update to the Reduction API and instances (#476)
* Simplify the macros for declaring and defining the add_device_reduce_instance_xxxx() instances * Change the types of lengths and strides from std::vector to std::array for the reduction device interfaces * Remove DeviceSoftmaxImpl's depending on DeviceReduceMultiblock * Split the cpp and hpp files for reduction instances to enable more parallel compiling * Remove the using of macros for declaring reduction instances and instance references * Update to add_device_reduce_instance_xxxx templated functions * Use ReduceOperation+InElementwiseOp+AccElementwiseOp to repace the ReduceOpId in defining add_reduce_instance_xxxx() templates * Change return format
This commit is contained in:
@@ -90,15 +90,15 @@ static bool time_kernel;
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// used by the device reduction
|
||||
const std::vector<int> reduceDims_1 = {4};
|
||||
const std::vector<int> invariantDims_1 = {0, 1, 2, 3};
|
||||
const std::array<int, 1> reduceDims_1 = {4};
|
||||
// const std::array<int, 4> invariantDims_1 = {0, 1, 2, 3};
|
||||
|
||||
const std::vector<int> reduceDims_2 = {3};
|
||||
const std::vector<int> invariantDims_2 = {0, 1, 2};
|
||||
const std::array<int, 1> reduceDims_2 = {3};
|
||||
// const std::array<int, 3> invariantDims_2 = {0, 1, 2};
|
||||
|
||||
// used by the host reduction
|
||||
const std::vector<int> reduceDims = {3, 4};
|
||||
const std::vector<int> invariantDims = {0, 1, 2};
|
||||
const std::array<int, 2> reduceDims = {3, 4};
|
||||
const std::array<int, 3> invariantDims = {0, 1, 2};
|
||||
|
||||
const std::vector<size_t> inLengths_1 = {64, 320, 80, 4, 128};
|
||||
|
||||
@@ -214,26 +214,26 @@ int main(int argc, char* argv[])
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::vector<ck::index_t> i_inLengths_1;
|
||||
std::vector<ck::index_t> i_inStrides_1;
|
||||
std::vector<ck::index_t> i_inLengths_2;
|
||||
std::vector<ck::index_t> i_inStrides_2;
|
||||
std::vector<ck::index_t> i_outLengths;
|
||||
std::vector<ck::index_t> i_outStrides;
|
||||
std::array<index_t, 5> arrInLengths_1;
|
||||
std::array<index_t, 5> arrInStrides_1;
|
||||
std::array<index_t, 4> arrInLengths_2;
|
||||
std::array<index_t, 4> arrInStrides_2;
|
||||
std::array<index_t, 3> arrOutLengths;
|
||||
std::array<index_t, 3> arrOutStrides;
|
||||
|
||||
i_inLengths_1.assign(inLengths_1.begin(), inLengths_1.end());
|
||||
i_inStrides_1.assign(inStrides_1.begin(), inStrides_1.end());
|
||||
i_inLengths_2.assign(inLengths_2.begin(), inLengths_2.end());
|
||||
i_inStrides_2.assign(inStrides_2.begin(), inStrides_2.end());
|
||||
i_outLengths.assign(outLengths.begin(), outLengths.end());
|
||||
i_outStrides.assign(outStrides.begin(), outStrides.end());
|
||||
std::copy(inLengths_1.begin(), inLengths_1.end(), arrInLengths_1.begin());
|
||||
std::copy(inStrides_1.begin(), inStrides_1.end(), arrInStrides_1.begin());
|
||||
std::copy(inLengths_2.begin(), inLengths_2.end(), arrInLengths_2.begin());
|
||||
std::copy(inStrides_2.begin(), inStrides_2.end(), arrInStrides_2.begin());
|
||||
std::copy(outLengths.begin(), outLengths.end(), arrOutLengths.begin());
|
||||
std::copy(outStrides.begin(), outStrides.end(), arrOutStrides.begin());
|
||||
|
||||
auto reduce_1 = DeviceReduceInstance_1{};
|
||||
|
||||
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(i_inLengths_1,
|
||||
i_inStrides_1,
|
||||
i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(arrInLengths_1,
|
||||
arrInStrides_1,
|
||||
arrInLengths_2,
|
||||
arrInStrides_2,
|
||||
reduceDims_1,
|
||||
1.0f,
|
||||
0.0f,
|
||||
@@ -255,10 +255,10 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto reduce_2 = DeviceReduceInstance_2{};
|
||||
|
||||
auto argument_ptr_2 = reduce_2.MakeArgumentPointer(i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
auto argument_ptr_2 = reduce_2.MakeArgumentPointer(arrInLengths_2,
|
||||
arrInStrides_2,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims_2,
|
||||
alpha,
|
||||
beta,
|
||||
|
||||
Reference in New Issue
Block a user