Fix identity sigmoid activation (#659)

* activation support Identity

* fix Sigmoid activation operator() with CUTLASS_HOST_DEVICE
This commit is contained in:
seventh
2022-11-10 03:42:23 +08:00
committed by GitHub
parent 168ea8b0e1
commit 06eb90cc0d

View File

@@ -49,18 +49,6 @@ namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct Identity {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
T operator()(T value) const {
return value;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct LinearCombinationGenericParams {
@@ -95,6 +83,39 @@ struct LinearCombinationGenericParams {
/////////////////////////////////////////////////////////////////////////////////////////////////
// Identity operator
template <typename T>
struct Identity {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
T operator()(T value) const {
return value;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
T operator()(T const &value, Params const &params_) const {
return this->operator()(value);
}
};
template <typename T, int N>
struct Identity<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
return rhs;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
}
};
/// ReLu operator - propagates NaNs
/// Always put threshold in the right hand side of max to propagate NaN.
template <typename T>