mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
BatchNorm forward instance/external api/profiler/tests/client example (#511)
* Update to device_batchnorm_forward base class to include all template parameters for problem description * Add batchnorm forward instances and external api * Add batchnorm forward profiler module which uses the external api * Add some comments in batchnorm_forward example to explain the dimensions in lengths[] * Replace the reference_batchnorm_forward_nhwc_c by generic reference_batchnorm_forward * Improvement to the batchnorm infer base API * Add batchnorm forward client example which shows using the batchnorm forward external API * Add test for batchnorm forward * Tuning the batchnorm profiler initialized values and error threshold * Add support for bhalf_t in instances/external api/tests * Add support for int8_t in instances/external api/tests * Add support for double in instances/external api/tests * Let ScaleDataType and BiasDataType be same as XDataType and YDataType when creating instances * Checking before running best instance in batchnorm_fwd_nhwc client example * Add checking for YElementwiseOp in batchnorm_forward external API * Add more types in batchnorm forward profiler * Add more test lengths Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
@@ -142,6 +142,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
constexpr int Rank = 4;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
|
||||
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
|
||||
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
|
||||
|
||||
// input data of the batchnorm forward algorithm
|
||||
@@ -300,7 +302,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
{0, 1, 2},
|
||||
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
@@ -366,13 +368,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
{
|
||||
|
||||
using ReferenceBatchNormFwdInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThroughOp>;
|
||||
ck::tensor_operation::host::ReferenceBatchNormFwd<InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThroughOp,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
|
||||
|
||||
@@ -380,7 +384,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
{0, 1, 2},
|
||||
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp"
|
||||
|
||||
#include "batchnorm_infer_impl.hpp"
|
||||
|
||||
@@ -124,6 +125,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
||||
constexpr int Rank = 4;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
|
||||
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
|
||||
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
|
||||
|
||||
// input data of the batchnorm forward algorithm
|
||||
@@ -260,20 +263,25 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ReferenceBatchNormInferInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<
|
||||
InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType>;
|
||||
ck::tensor_operation::host::ReferenceBatchNormInfer<InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThroughOp,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
|
||||
|
||||
auto argument_ptr_ref =
|
||||
batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
{0, 1, 2},
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
@@ -282,6 +290,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
||||
bnScale.mData.data(),
|
||||
bnBias.mData.data(),
|
||||
epsilon,
|
||||
PassThroughOp{},
|
||||
estimatedMean.mData.data(),
|
||||
estimatedVariance.mData.data(),
|
||||
y_ref.mData.data());
|
||||
|
||||
Reference in New Issue
Block a user