mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
Add client example of grouped conv2d forward (data type: fp16) (#488)
* Rename example folder for GroupedConvFwdMultipleD * Unify example codes * Change target names * Add fp16 example for multiple d instance * Re-format common.hpp * Add interface 'DeviceGroupedConvFwd' * Use simpler interface * Move common conv params out * Rename conv fwd client example folder * Add missing include directive * Update grouped conv instance implementations * Simplify ckProfiler (grouped conv forward) * Use GroupedConvFwd to implement client example * Use greater groupe count in example * Add custom target to group examples * Add extra tag param to instance factory function * Use tag to differentiate factory functions * Add missing tag argument for factory function * Remove inheritance relationship * Remove no-longer used include directive * Add license in front of file
This commit is contained in:
@@ -14,39 +14,38 @@ namespace device {
|
||||
// Convolution Forward:
|
||||
// input : input image A[G, N, C, Hi, Wi],
|
||||
// input : weight B[G, K, C, Y, X],
|
||||
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
|
||||
// output : output image E[G, N, K, Ho, Wo]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceGroupedConvFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a, // input image
|
||||
const void* p_b, // weight
|
||||
void* p_c, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
|
||||
MakeArgumentPointer(const void* p_in, // input image
|
||||
const void* p_wei, // weight
|
||||
void* p_out, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op) = 0;
|
||||
const InElementwiseOperation& in_element_op,
|
||||
const WeiElementwiseOperation& wei_element_op,
|
||||
const OutElementwiseOperation& out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user