mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
BatchNorm forward instance/external api/profiler/tests/client example (#511)
* Update the device_batchnorm_forward base class to include all template parameters for the problem description
* Add batchnorm forward instances and external API
* Add a batchnorm forward profiler module which uses the external API
* Add comments in the batchnorm_forward example to explain the dimensions in lengths[]
* Replace reference_batchnorm_forward_nhwc_c with the generic reference_batchnorm_forward
* Improve the batchnorm infer base API
* Add a batchnorm forward client example which shows use of the batchnorm forward external API
* Add a test for batchnorm forward
* Tune the batchnorm profiler initialized values and error threshold
* Add support for bhalf_t in instances/external API/tests
* Add support for int8_t in instances/external API/tests
* Add support for double in instances/external API/tests
* Let ScaleDataType and BiasDataType be the same as XDataType and YDataType when creating instances
* Check before running the best instance in the batchnorm_fwd_nhwc client example
* Add checking for YElementwiseOp in the batchnorm forward external API
* Add more types to the batchnorm forward profiler
* Add more test lengths

Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -13,7 +13,15 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
struct DeviceBatchNormFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
@@ -40,9 +48,24 @@ struct DeviceBatchNormFwd : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
|
||||
using DeviceBatchNormFwdPtr =
|
||||
std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -13,13 +13,22 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t Rank, index_t NumBatchNormReduceDim>
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
struct DeviceBatchNormInfer : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> yStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
|
||||
@@ -28,6 +37,7 @@ struct DeviceBatchNormInfer : public BaseOperator
|
||||
const void* bnScale,
|
||||
const void* bnBias,
|
||||
double epsilon,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
const void* estimatedMean,
|
||||
const void* estimatedInvVariance,
|
||||
void* p_y) = 0;
|
||||
@@ -35,8 +45,24 @@ struct DeviceBatchNormInfer : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <index_t Rank, index_t NumBatchNormReduceDim>
|
||||
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<Rank, NumBatchNormReduceDim>>;
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -42,8 +42,15 @@ template <typename XDataType,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t BiasSrcVectorSize,
|
||||
index_t MeanVarSrcDstVectorSize>
|
||||
struct DeviceBatchNormFwdImpl
|
||||
: public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>
|
||||
struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
|
||||
Reference in New Issue
Block a user