mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
example for convnd bwd weight bf16 splitk (#265)
* add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight * add bwd weight for bf16: init * remove redundant compute * use datatype and split k to check whether a workspace is used * remove unused computation for work space size * add some code for bfp16 * add device/grid unary op * add unary type convert to bwd-weight example * support bf16 splitk kernel for convnd bwd weight * 1. remove comments. 2. add checkvalidity. 3. add gridsize computation * add workspace size check * fix format * change function name
This commit is contained in:
@@ -346,6 +346,27 @@ struct UnarySqrt<double, double>
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
struct UnaryTypeConvert;
|
||||
|
||||
template <>
|
||||
struct UnaryTypeConvert<float, ck::bhalf_t>
|
||||
{
|
||||
__host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
|
||||
{
|
||||
y = ck::type_convert<float, ck::bhalf_t>(x);
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryTypeConvert<ck::bhalf_t, float>
|
||||
{
|
||||
__host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
|
||||
{
|
||||
y = ck::type_convert<ck::bhalf_t, float>(x);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user