Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-03 05:01:25 +00:00)
batched_gemm + multiple_d + gemm + multiple_d (#394)
* refactor
* start
* add device gemm file
* add BatchStrideD0
* add StrideD0
* add gridwise file
* add d0 parameters to gridwise gemm
* add c layout transformer
* add d0 threadwise copy
* init kernel
* init kernel
* regular code
* nm desc put to out
* kernel parameters cannot use references
* host: add bias+gelu
* bias+gelu runs correctly
* move AddFastGelu into another file
* add d1 bias parameters to the interface
* add d1 parameter to argument
* add d1 parameter to gridwise
* first complete code, not yet verified
* change gelu to relu and fix GetElementSpaceSize bug
* add instance
* start adding to ckProfiler
* finish ckProfiler code
* change input parameters for ckProfiler
* fix host bias+gelu bug
* show help in ckProfiler
* fix bug where kernel launch ignored parameters
* add padding and fix related bug
* multiple d0
* add dynamic d0_element_op
* change profiler and instance to multiple d0
* example has 2 d0
* remove unused comments
* give the 2 d0 their own parameters
* rename d element_op
* rename class (multiple_d)
* fix bug
* fix missing-file bug
* update profiler
* refactor
* update profiler
* clean
* revert example change
* add gon layout
* optimize parameter for gno
* add gon to gemm+gemm
* change help input parameters
* change to GemmPadder_v2
* use ForEach
* fix gb_per_sec

Co-authored-by: Chao Liu <lc.roy86@gmail.com>
Co-authored-by: ltqin <letaoqin@amd.com>
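At its core, the new device op fuses an element-wise epilogue into the GEMM: per output element it computes E = FastGelu(C + D0) with C = A*B, as spelled out by the comment on the AddFastGelu functor in the diff below. A host-side reference sketch of that math, assuming row-major layouts, a single batch, and a naive triple loop (illustrative only, not the tiled device kernel this PR adds):

// Host-side reference for the fused epilogue E = FastGelu(A*B + D0).
// Illustrative only: plain loops, row-major layouts, one batch.
#include <cmath>
#include <vector>

// same approximation as AddFastGelu::GetFastGeLU in the diff below
static float fast_gelu(float x)
{
    const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
    return x * (0.5f + 0.5f * (2.f / (1.f + std::exp(-u)) - 1.f));
}

void reference_gemm_add_fastgelu(int M, int N, int K,
                                 const std::vector<float>& a,  // M x K
                                 const std::vector<float>& b,  // K x N
                                 const std::vector<float>& d0, // M x N (e.g. a pre-broadcast bias)
                                 std::vector<float>& e)        // M x N
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float c = 0.f;
            for(int k = 0; k < K; ++k)
                c += a[m * K + k] * b[k * N + n];        // C = A * B
            e[m * N + n] = fast_gelu(c + d0[m * N + n]); // E = FastGelu(C + D0)
        }
}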
@@ -28,6 +28,13 @@ struct Add
        y = x0 + x1;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const half_t& x1) const
    {
        y = x0 + type_convert<half_t>(x1);
    };

    template <>
    __host__ __device__ constexpr void
    operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
@@ -172,6 +179,14 @@ struct AddRelu
        const float a = x0 + x1;
        y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
    };

    template <>
    __host__ __device__ constexpr void
    operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
    {
        const float a = x0 + type_convert<float>(x1);
        y = a > 0.0f ? a : 0.0f;
    };
};

struct AddHardswish
@@ -210,6 +225,46 @@ struct AddHardswish
    };
};

// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
    // Fast GeLU
    // https://paperswithcode.com/method/gelu
    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
    __host__ __device__ static constexpr float GetFastGeLU(float x)
    {
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float emu = exp(-u);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
        return x * cdf;
    }

    template <typename T>
    static inline constexpr bool is_valid_param_type_v =
        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;

    template <typename E, typename C, typename D>
    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
    {
        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
                      is_valid_param_type_v<D>);

        const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));

        e = type_convert<E>(y);
    }

    template <typename D>
    __host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
    {
        static_assert(is_valid_param_type_v<D>);

        e = GetFastGeLU(c + type_convert<float>(d));
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
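A side note on GetFastGeLU above: it avoids calling tanh directly by using the identity 2/(1 + e^-u) - 1 = tanh(u/2), with u = 2*sqrt(2/pi)*(x + 0.044715*x^3); the constants 0.797885 and 0.035677 are sqrt(2/pi) and 0.044715*sqrt(2/pi). A small host-only sketch (not part of the diff) that compares the exp-based rewrite against the tanh form quoted in the comment:

// Host-only check that the exp-based rewrite in GetFastGeLU matches the tanh form.
#include <cmath>
#include <cstdio>

// exp-based form, as in GetFastGeLU above
static float fast_gelu_exp(float x)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = std::exp(-u);
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
    return x * cdf;
}

// reference tanh form: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
static float fast_gelu_tanh(float x)
{
    return 0.5f * x * (1.f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

int main()
{
    for(float x = -4.f; x <= 4.f; x += 0.5f)
    {
        std::printf("x=% .2f  exp-form=% .6f  tanh-form=% .6f\n",
                    x, fast_gelu_exp(x), fast_gelu_tanh(x));
    }
}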
@@ -211,6 +211,27 @@ struct FastGelu
    }
};

// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const;

    template <>
    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
    {
        y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
    }

    template <>
    __host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y,
                                                                 const ck::half_t& x) const
    {
        y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x))));
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
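Both functors follow the same element-wise convention: a stateless struct whose operator() writes the result through its first reference argument. A hedged usage sketch, compiled with the CK headers available; the include paths below are assumptions about where these headers live, not taken from this diff:

// Hypothetical include paths for the element-wise operation headers.
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"

using ck::tensor_operation::element_wise::Gelu;
using ck::tensor_operation::element_wise::AddFastGelu;

float demo(float c, float d)
{
    float e_gelu = 0.f;
    Gelu{}(e_gelu, c);            // exact erf-based GeLU of c

    float e_fused = 0.f;
    AddFastGelu{}(e_fused, c, d); // FastGelu(c + d), the fused bias+gelu epilogue
    return e_gelu + e_fused;
}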