example for convnd bwd weight bf16 splitk (#265)

* add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight

* add bwd weight for bf16: init

* remove redundant compute

* use datatype and split k to check whether a workspace is used

* remove unused computation for work space size

* add some code for bfp16

* add device/grid unary op

* add unary type convert to bwd-weight example

* support bf16 splitk kernel for convnd bwd weight

* 1. remove comments. 2. add checkvalidity. 3. add gridsize computation

* add workspace size check

* fix format

* change function name
This commit is contained in:
Shaojie WANG
2022-06-17 03:16:01 +08:00
committed by GitHub
parent fb9b6b1e33
commit 561ec12f4a
8 changed files with 1021 additions and 72 deletions

View File

@@ -346,6 +346,27 @@ struct UnarySqrt<double, double>
};
};
// Element-wise functor that converts a value of type X into a value of type Y.
// Only the declaration of the primary template is visible here; usable
// behavior comes from explicit specializations for concrete (Y, X) pairs.
template <typename Y, typename X>
struct UnaryTypeConvert;
// Specialization converting a ck::bhalf_t input into a float output.
template <>
struct UnaryTypeConvert<float, ck::bhalf_t>
{
    // Stores in `y` the value returned by ck::type_convert<float, ck::bhalf_t>(x).
    // `x` is taken by non-const reference, so callers must pass an lvalue of
    // exactly type ck::bhalf_t.
    __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
    {
        const float widened = ck::type_convert<float, ck::bhalf_t>(x);
        y                   = widened;
    }
};
// Specialization converting a float input into a ck::bhalf_t output.
template <>
struct UnaryTypeConvert<ck::bhalf_t, float>
{
    // Stores in `y` the value returned by ck::type_convert<ck::bhalf_t, float>(x).
    // `x` is taken by non-const reference, so callers must pass an lvalue of
    // exactly type float.
    __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
    {
        const ck::bhalf_t narrowed = ck::type_convert<ck::bhalf_t, float>(x);
        y                          = narrowed;
    }
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck