mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
Pool3d fwd (#697)
* Expand the base class of pool2d, prepare to share base class with pool3d
* Add pool3d device op
* Add pool3d f16 example
* Refactor the base class. implement generic pooling in the future
* clang format
* get original index in max pooling
* Add outputindex to base class
* Fix dimension
* Add pooling instance
* Use indexType instead
* Remove useless header
* Extract IndexDataType to template
* Extract pooling reference code
* clang format
* clang format
* Fix typo
* Add tensor stride
* Add missing header
* Add index stride and output stride
* Refine naming
* Add type to base class
* Rename file
* Use proper size
* Fix typo
* Refine naming
* Modify the argument into vector.
* Add max pool profiler
* Refine naming
* Support f32 pool
* Fix typo
* Add avg pool2d fwd in profiler
* clang format
* Rename AccDatatype to ComputeDatatype
* Fix init
* test pool
* Extract variable
* Add client example
* Check the pooling dim
* clang format
* Connect argv and arg_parser
* Add found check
* Remove useless header
* Refine naming
* Adjust the order of device_pool_fwd
[ROCm/composable_kernel commit: 76ec0089fb]
This commit is contained in:
@@ -15,6 +15,7 @@ namespace ck {
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
bool OutputIndex,
|
||||
bool TransformIndexKtoGlobal,
|
||||
bool HaveIndexInput,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
@@ -48,16 +49,17 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_value_global,
|
||||
p_in_index_global,
|
||||
beta,
|
||||
p_out_value_global,
|
||||
p_out_index_global);
|
||||
GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>(
|
||||
in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_value_global,
|
||||
p_in_index_global,
|
||||
beta,
|
||||
p_out_value_global,
|
||||
p_out_index_global);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -232,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
|
||||
};
|
||||
|
||||
template <bool HaveIndexInput>
|
||||
template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
|
||||
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const OutGridDesc_M& out_grid_desc_m,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
@@ -390,6 +392,18 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
indexStart += KThreadSliceSize;
|
||||
reducedLength += KThreadSliceSize;
|
||||
} while(reducedLength < toReduceLength);
|
||||
|
||||
if constexpr(TransformIndexKtoGlobal)
|
||||
{
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
const auto coord = make_tensor_coordinate(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(thread_global_1d_id * MThreadSliceSize + I,
|
||||
accu_index_buf(I)));
|
||||
|
||||
accu_index_buf(I) = coord.GetOffset();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// for indexed operation, acc_elementwise_op should do nothing
|
||||
|
||||
Reference in New Issue
Block a user