mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
BatchNorm backward instance/external API/profiler/tests (#519)
* Refine the device batchnorm-backward base API templates and data type assignments * Remove duplicated kernel file * Add batchnorm backward instances and external API * Add batchnorm-backward profiler and tests * Add client example which uses batchnorm backward external API * Merge test/batchnorm_fwd and test/batchnorm_bwd into one directory * Loose the threshold for batchnorm-backward check_err()
This commit is contained in:
@@ -27,7 +27,7 @@ template <typename XDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
index_t Rank,
|
||||
@@ -42,11 +42,19 @@ template <typename XDataType,
|
||||
index_t XSrcVectorSize,
|
||||
index_t DySrcVectorSize,
|
||||
index_t DxDstVectorSize,
|
||||
index_t ScaleSrcDstVectorSize,
|
||||
index_t BiasDstVectorSize,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t DscaleDbiasDstVectorSize,
|
||||
index_t MeanVarSrcVectorSize>
|
||||
struct DeviceBatchNormBwdImpl
|
||||
: public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>
|
||||
struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
@@ -194,7 +202,7 @@ struct DeviceBatchNormBwdImpl
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
|
||||
const XDataType* p_x,
|
||||
const DyDataType* p_dy,
|
||||
@@ -204,11 +212,11 @@ struct DeviceBatchNormBwdImpl
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
double epsilon,
|
||||
DxDataType* p_dx,
|
||||
ScaleDataType* p_dscale,
|
||||
BiasDataType* p_dbias)
|
||||
DscaleDbiasDataType* p_dscale,
|
||||
DscaleDbiasDataType* p_dbias)
|
||||
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
|
||||
bnScaleStrides_(bnScaleStrides),
|
||||
bnBiasStrides_(bnBiasStrides),
|
||||
bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
|
||||
bnMeanVarStrides_(bnMeanVarStrides),
|
||||
p_x_(p_x),
|
||||
p_dy_(p_dy),
|
||||
@@ -272,8 +280,8 @@ struct DeviceBatchNormBwdImpl
|
||||
MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration);
|
||||
scale_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
|
||||
bias_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides);
|
||||
dscale_dbias_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides);
|
||||
mean_var_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
|
||||
}
|
||||
@@ -289,7 +297,7 @@ struct DeviceBatchNormBwdImpl
|
||||
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnDscaleDbiasStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
|
||||
|
||||
const XDataType* p_x_;
|
||||
@@ -299,8 +307,8 @@ struct DeviceBatchNormBwdImpl
|
||||
const MeanVarDataType* p_savedInvVar_;
|
||||
const DyElementwiseOp dy_elementwise_op_;
|
||||
DxDataType* p_dx_;
|
||||
ScaleDataType* p_dscale_;
|
||||
BiasDataType* p_dbias_;
|
||||
DscaleDbiasDataType* p_dscale_;
|
||||
DscaleDbiasDataType* p_dbias_;
|
||||
|
||||
long_index_t invariant_length;
|
||||
long_index_t reduce_length;
|
||||
@@ -313,7 +321,7 @@ struct DeviceBatchNormBwdImpl
|
||||
XYGridDesc_M_K dy_grid_desc_m_k;
|
||||
XYGridDesc_M_K dx_grid_desc_m_k;
|
||||
ScaleBiasGridDesc_M scale_grid_desc_m;
|
||||
ScaleBiasGridDesc_M bias_grid_desc_m;
|
||||
ScaleBiasGridDesc_M dscale_dbias_grid_desc_m;
|
||||
MeanVarGridDesc_M mean_var_grid_desc_m;
|
||||
|
||||
void* workspace_mean;
|
||||
@@ -337,11 +345,11 @@ struct DeviceBatchNormBwdImpl
|
||||
{
|
||||
// workspace for the partial reduced result for dscale
|
||||
workspace_size +=
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType) + 64;
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
|
||||
|
||||
// workspace for the partial reduced result for dbias
|
||||
workspace_size +=
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType) + 64;
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
|
||||
|
||||
if(!pArg_->haveSavedMeanInvVar_)
|
||||
{
|
||||
@@ -379,7 +387,7 @@ struct DeviceBatchNormBwdImpl
|
||||
// setup buffer for the partial reduced result for dscale
|
||||
pArg_->workspace_reduce_dscale = pArg_->p_workspace_;
|
||||
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
|
||||
space_sz = math::integer_least_multiple(space_sz, 64);
|
||||
|
||||
// setup buffer for the partial reduced result for dbias
|
||||
@@ -388,7 +396,7 @@ struct DeviceBatchNormBwdImpl
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize > 1)
|
||||
{
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType);
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
|
||||
space_sz = math::integer_least_multiple(space_sz, 64);
|
||||
|
||||
// setup buffer for welford intermediate mean
|
||||
@@ -454,7 +462,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -477,7 +485,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DxDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -493,8 +501,8 @@ struct DeviceBatchNormBwdImpl
|
||||
XSrcVectorSize,
|
||||
DySrcVectorSize,
|
||||
DxDstVectorSize,
|
||||
ScaleSrcDstVectorSize,
|
||||
BiasDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
DscaleDbiasDstVectorSize,
|
||||
MeanVarSrcVectorSize>;
|
||||
|
||||
if(UseMultiblockInK && arg.blkGroupSize > 1)
|
||||
@@ -553,7 +561,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -568,7 +576,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DyDataType,
|
||||
DxDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -614,8 +622,8 @@ struct DeviceBatchNormBwdImpl
|
||||
: static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
|
||||
arg.p_x_,
|
||||
arg.p_dy_,
|
||||
static_cast<ScaleDataType*>(arg.workspace_reduce_dscale),
|
||||
static_cast<BiasDataType*>(arg.workspace_reduce_dbias));
|
||||
static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
|
||||
static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dbias));
|
||||
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
@@ -629,13 +637,13 @@ struct DeviceBatchNormBwdImpl
|
||||
dscale_dbias_grid_desc_m_k,
|
||||
arg.mean_var_grid_desc_m,
|
||||
arg.scale_grid_desc_m,
|
||||
arg.bias_grid_desc_m,
|
||||
arg.dscale_dbias_grid_desc_m,
|
||||
arg.blkGroupSize,
|
||||
arg.reduce_length,
|
||||
arg.numBlockTileIteration,
|
||||
numDscaleDbiasBlockTileIteration,
|
||||
static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
|
||||
static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
|
||||
static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
|
||||
static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dbias),
|
||||
arg.haveSavedMeanInvVar_
|
||||
? arg.p_savedMean_
|
||||
: static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
|
||||
@@ -664,7 +672,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DxDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -680,8 +688,8 @@ struct DeviceBatchNormBwdImpl
|
||||
XSrcVectorSize,
|
||||
DySrcVectorSize,
|
||||
DxDstVectorSize,
|
||||
ScaleSrcDstVectorSize,
|
||||
BiasDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
DscaleDbiasDstVectorSize,
|
||||
MeanVarSrcVectorSize>;
|
||||
|
||||
const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
|
||||
@@ -691,7 +699,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DxDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -708,7 +716,7 @@ struct DeviceBatchNormBwdImpl
|
||||
arg.dy_grid_desc_m_k,
|
||||
arg.dx_grid_desc_m_k,
|
||||
arg.scale_grid_desc_m,
|
||||
arg.bias_grid_desc_m,
|
||||
arg.dscale_dbias_grid_desc_m,
|
||||
arg.mean_var_grid_desc_m,
|
||||
get_reduce_count_per_thread,
|
||||
arg.reduce_length,
|
||||
@@ -764,16 +772,16 @@ struct DeviceBatchNormBwdImpl
|
||||
return false;
|
||||
};
|
||||
|
||||
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcDstVectorSize != 1)
|
||||
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasDstVectorSize != 1)
|
||||
if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcDstVectorSize != 0)
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasDstVectorSize != 0)
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0)
|
||||
return false;
|
||||
|
||||
if(pArg_->haveSavedMeanInvVar_)
|
||||
@@ -806,7 +814,7 @@ struct DeviceBatchNormBwdImpl
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_dy,
|
||||
@@ -826,7 +834,7 @@ struct DeviceBatchNormBwdImpl
|
||||
reduceDims,
|
||||
bnScaleBiasMeanVarLengths,
|
||||
bnScaleStrides,
|
||||
bnBiasStrides,
|
||||
bnDscaleDbiasStrides,
|
||||
bnMeanVarStrides,
|
||||
static_cast<const XDataType*>(p_x),
|
||||
static_cast<const DyDataType*>(p_dy),
|
||||
@@ -836,8 +844,8 @@ struct DeviceBatchNormBwdImpl
|
||||
dy_elementwise_op,
|
||||
epsilon,
|
||||
static_cast<DxDataType*>(p_dx),
|
||||
static_cast<ScaleDataType*>(p_dscale),
|
||||
static_cast<BiasDataType*>(p_dbias));
|
||||
static_cast<DscaleDbiasDataType*>(p_dscale),
|
||||
static_cast<DscaleDbiasDataType*>(p_dbias));
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
@@ -854,7 +862,7 @@ struct DeviceBatchNormBwdImpl
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcDstVectorSize << "_bias_" << BiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
Reference in New Issue
Block a user