mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
* gemm+activation
* move C pointwise operation into threadwise copy
* add pointwise operation to A/B matrix
* update ckProfiler
* adding bias add
* adding bias add
* adding bias add
* added bias add; worked around compiler issues
* clean up
* clean up
* Update README.md
* Update README.md
* Update README.md
* clean up
* add conv_xdl example
* adding conv_xdl_bias_relu_add example
* add conv+bias+relu+add, but it has a register spill issue
* tweak
* tweak
* refactor
* Update README.md
Update the README for example/2_gemm_xdl_bias_relu_add
* clean up
* Update README.md
Update the README for example/3_conv_xdl
* Update README.md
[ROCm/composable_kernel commit: 41cdd3801a]
111 lines
4.6 KiB
C++
111 lines
4.6 KiB
C++
#ifndef DEVICE_CONV_HPP
|
|
#define DEVICE_CONV_HPP
|
|
|
|
#include <iostream>
#include <memory>
#include <vector>

#include "device_base.hpp"
|
|
|
|
namespace ck {
|
|
namespace tensor_operation {
|
|
namespace device {
|
|
|
|
template <typename InElementwiseOperation,
|
|
typename WeiElementwiseOperation,
|
|
typename OutElementwiseOperation>
|
|
struct DeviceConvFwd : public BaseOperator
|
|
{
|
|
virtual std::unique_ptr<BaseArgument>
|
|
MakeArgumentPointer(const void* p_in,
|
|
const void* p_wei,
|
|
void* p_out,
|
|
ck::index_t N,
|
|
ck::index_t K,
|
|
ck::index_t C,
|
|
std::vector<ck::index_t> input_spatial_lengths,
|
|
std::vector<ck::index_t> filter_spatial_lengths,
|
|
std::vector<ck::index_t> output_spatial_lengths,
|
|
std::vector<ck::index_t> conv_filter_strides,
|
|
std::vector<ck::index_t> conv_filter_dilations,
|
|
std::vector<ck::index_t> input_left_pads,
|
|
std::vector<ck::index_t> input_right_pads,
|
|
InElementwiseOperation in_element_op,
|
|
WeiElementwiseOperation wei_element_op,
|
|
OutElementwiseOperation out_element_op) = 0;
|
|
|
|
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
|
};
|
|
|
|
template <typename InElementwiseOperation,
|
|
typename WeiElementwiseOperation,
|
|
typename OutElementwiseOperation>
|
|
struct DeviceConvBwd : public BaseOperator
|
|
{
|
|
virtual std::unique_ptr<BaseArgument>
|
|
MakeArgumentPointer(void* p_in,
|
|
const void* p_wei,
|
|
const void* p_out,
|
|
ck::index_t N,
|
|
ck::index_t K,
|
|
ck::index_t C,
|
|
std::vector<ck::index_t> input_spatial_lengths,
|
|
std::vector<ck::index_t> filter_spatial_lengths,
|
|
std::vector<ck::index_t> output_spatial_lengths,
|
|
std::vector<ck::index_t> conv_filter_strides,
|
|
std::vector<ck::index_t> conv_filter_dilations,
|
|
std::vector<ck::index_t> input_left_pads,
|
|
std::vector<ck::index_t> input_right_pads,
|
|
InElementwiseOperation in_element_op,
|
|
WeiElementwiseOperation wei_element_op,
|
|
OutElementwiseOperation out_element_op) = 0;
|
|
|
|
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
|
};
|
|
|
|
template <typename InElementwiseOperation,
|
|
typename WeiElementwiseOperation,
|
|
typename OutElementwiseOperation>
|
|
struct DeviceConvWrw : public BaseOperator
|
|
{
|
|
virtual std::unique_ptr<BaseArgument>
|
|
MakeArgumentPointer(const void* p_in,
|
|
void* p_wei,
|
|
const void* p_out,
|
|
ck::index_t N,
|
|
ck::index_t K,
|
|
ck::index_t C,
|
|
std::vector<ck::index_t> input_spatial_lengths,
|
|
std::vector<ck::index_t> filter_spatial_lengths,
|
|
std::vector<ck::index_t> output_spatial_lengths,
|
|
std::vector<ck::index_t> conv_filter_strides,
|
|
std::vector<ck::index_t> conv_filter_dilations,
|
|
std::vector<ck::index_t> input_left_pads,
|
|
std::vector<ck::index_t> input_right_pads,
|
|
InElementwiseOperation in_element_op,
|
|
WeiElementwiseOperation wei_element_op,
|
|
OutElementwiseOperation out_element_op) = 0;
|
|
|
|
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
|
};
|
|
|
|
template <typename InElementwiseOperation,
|
|
typename WeiElementwiseOperation,
|
|
typename OutElementwiseOperation>
|
|
using DeviceConvFwdPtr = std::unique_ptr<
|
|
DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
|
|
|
template <typename InElementwiseOperation,
|
|
typename WeiElementwiseOperation,
|
|
typename OutElementwiseOperation>
|
|
using DeviceConvBwdPtr = std::unique_ptr<
|
|
DeviceConvBwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
|
|
|
template <typename InElementwiseOperation,
|
|
typename WeiElementwiseOperation,
|
|
typename OutElementwiseOperation>
|
|
using DeviceConvWrwPtr = std::unique_ptr<
|
|
DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
|
|
|
} // namespace device
|
|
} // namespace tensor_operation
|
|
} // namespace ck
|
|
#endif
|