Files
composable_kernel/device_operation/include/device_gemm_bias_activation.hpp
Chao Liu 823657ed12 GEMM+Bias+ReLU+Add (#76)
* tweak conv for odd C

* update script

* clean up elementwise op

* fix build

* clean up

* added example for gemm+bias+relu+add

* added example for gemm+bias+relu

* add profiler for gemm_s_shuffle; re-org files

* add profiler

* fix build

* clean up

* clean up

* clean up

* fix build
2022-02-06 22:32:47 -06:00

44 lines
1.9 KiB
C++

#ifndef DEVICE_GEMM_BIAS_ACTIVATION_HPP
#define DEVICE_GEMM_BIAS_ACTIVATION_HPP
#include <iostream>
#include <memory>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/// Abstract interface for a device operation; judging by the type name it
/// covers a GEMM fused with a bias term and an activation (the computation
/// itself is supplied by implementers, not by this header).
///
/// The interface only declares two pure-virtual factory methods: one that
/// packages the call parameters into a BaseArgument, and one that produces a
/// BaseInvoker.
///
/// @tparam AElementwiseOperation Type of the `a_element_op` object.
/// @tparam BElementwiseOperation Type of the `b_element_op` object.
/// @tparam CElementwiseOperation Type of the `c_element_op` object.
template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemmBiasActivation : public BaseOperator
{
    /// Builds the argument object for one invocation.
    ///
    /// Buffers are passed as untyped (`void`) pointers; only `p_c` is
    /// non-const.
    ///
    /// @param p_a          Read-only buffer (presumably the A matrix -- confirm with implementers).
    /// @param p_b          Read-only buffer (presumably the B matrix -- confirm with implementers).
    /// @param p_c          Writable buffer (presumably the output C -- confirm with implementers).
    /// @param p_c0         Read-only buffer (presumably the bias -- confirm with implementers).
    /// @param M            Problem size (presumably GEMM M).
    /// @param N            Problem size (presumably GEMM N).
    /// @param K            Problem size (presumably GEMM K).
    /// @param StrideA      Stride associated with `p_a`.
    /// @param StrideB      Stride associated with `p_b`.
    /// @param StrideC      Stride associated with `p_c`.
    /// @param a_element_op Operation object of type AElementwiseOperation, taken by value.
    /// @param b_element_op Operation object of type BElementwiseOperation, taken by value.
    /// @param c_element_op Operation object of type CElementwiseOperation, taken by value.
    /// @param KBatch       Defaults to 1 (presumably a split-K factor -- confirm with implementers).
    /// @return Owning pointer to the newly created argument object.
    ///
    /// NOTE(review): `KBatch` carries a default on a virtual function. C++
    /// selects default arguments from the static type of the call expression,
    /// so overriders should repeat the same default to avoid surprises.
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        const void* p_c0,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op,
                        ck::index_t KBatch = 1) = 0;

    /// Creates an invoker object.
    ///
    /// @return Owning pointer to the newly created invoker.
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
/// Owning (`std::unique_ptr`) handle to a DeviceGemmBiasActivation interface
/// instantiated with the given three element-wise operation types.
template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceGemmBiasActivationPtr =
    std::unique_ptr<DeviceGemmBiasActivation<AElementwiseOperation,
                                             BElementwiseOperation,
                                             CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif