Pool3d fwd (#697)

* Expand the base class of pool2d, prepare to share base class with pool3d

* Add pool3d device op

* Add pool3d f16 example

* Refactor the base class. implement generic pooling in the future

* clang format

* get original index in max pooling

* Add outputindex to base class

* Fix dimension

* Add pooling instance

* Use indexType instead

* Remove useless header

* Extract IndexDataType to template

* Extract pooling reference code

* clang format

* clang format

* Fix typo

* Add tensor stride

* Add missing header

* Add index stride and output stride

* Refine naming

* Add type to base class

* Rename file

* Use proper size

* Fix typo

* Refine naming

* Modify the argument into vector.

* Add max pool profiler

* Refine naming

* Support f32 pool

* Fix typo

* Add avg pool2d fwd in profiler

* clang format

* Rename AccDatatype to ComputeDatatype

* Fix init

* test pool

* Extract variable

* Add client example

* Check the pooling dim

* clang format

* Connect argv and arg_parser

* Add found check

* Remove useless header

* Refine naming

* Adjust the order of device_pool_fwd

[ROCm/composable_kernel commit: 76ec0089fb]
This commit is contained in:
rocking
2023-05-24 22:05:04 +08:00
committed by GitHub
parent 96dbe62907
commit 84cbb3af35
44 changed files with 3226 additions and 241 deletions

View File

@@ -15,6 +15,7 @@ namespace ck {
template <typename GridwiseReduction,
bool OutputIndex,
bool TransformIndexKtoGlobal,
bool HaveIndexInput,
typename InDataType,
typename OutDataType,
@@ -48,16 +49,17 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
}
else
{
GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_value_global,
p_in_index_global,
beta,
p_out_value_global,
p_out_index_global);
GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>(
in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_value_global,
p_in_index_global,
beta,
p_out_value_global,
p_out_index_global);
};
};
@@ -232,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
};
template <bool HaveIndexInput>
template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
@@ -390,6 +392,18 @@ struct GridwiseReduction_mk_to_m_threadwise
indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
if constexpr(TransformIndexKtoGlobal)
{
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
const auto coord = make_tensor_coordinate(
in_grid_desc_m_k,
make_multi_index(thread_global_1d_id * MThreadSliceSize + I,
accu_index_buf(I)));
accu_index_buf(I) = coord.GetOffset();
});
}
};
// for indiced operation, acc_elementwise_op shoud do nothing