BatchNorm backward instance/external API/profiler/tests (#519)

* Refine the device batchnorm-backward base API templates and data type assignments

* Remove duplicated kernel file

* Add batchnorm backward instances and external API

* Add batchnorm-backward profiler and tests

* Add client example which uses batchnorm backward external API

* Merge test/batchnorm_fwd and test/batchnorm_bwd into one directory

* Loose the threshold for batchnorm-backward check_err()
This commit is contained in:
Qianfeng
2022-12-01 03:32:20 +08:00
committed by GitHub
parent 236bd148b9
commit 63af525c06
25 changed files with 2240 additions and 832 deletions

View File

@@ -16,7 +16,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename DyDataType,
typename DxDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
@@ -35,8 +35,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias,
const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x,
@@ -44,8 +44,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
@@ -76,7 +76,7 @@ template <typename XDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
@@ -92,8 +92,8 @@ template <typename XDataType,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t ScaleSrcVectorSize,
index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseReduceSecondHalfBatchNormBackwardFinal
{
@@ -155,13 +155,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m,
index_t blkgroup_size,
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias,
const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x,
@@ -169,8 +169,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
@@ -222,8 +222,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
// Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
auto threadwise_dscale_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
auto threadwise_dscale_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<DscaleDbiasDataType,
AccDataType,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
@@ -238,54 +238,20 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_dscale_store_m =
auto threadwise_dscale_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
DscaleDbiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
@@ -297,10 +263,10 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
constexpr auto dscale_dbias_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
@@ -313,25 +279,23 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
++reducedTiles)
{
threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dscale_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf);
threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dscale_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf);
threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dbias_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf);
threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dbias_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf);
ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -343,17 +307,17 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
});
threadwise_dscale_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
scale_grid_desc_m,
dscale_global_buf);
threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
dscale_dbias_grid_desc_m,
dscale_global_buf);
threadwise_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
bias_grid_desc_m,
dbias_global_buf);
threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
dscale_dbias_grid_desc_m,
dbias_global_buf);
// clang-format off
// Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
@@ -418,7 +382,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,

View File

@@ -17,7 +17,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
@@ -45,8 +45,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_dbias)
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
@@ -76,7 +76,7 @@ template <typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
@@ -174,8 +174,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_dbias)
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
@@ -511,28 +511,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
});
auto threadwise_dscale_store =
auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
@@ -557,17 +538,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
if(thread_k_cluster_id == 0)
{
threadwise_dscale_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dscale_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dscale_global_buf);
threadwise_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dbias_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dbias_global_buf);
};
};
};

View File

@@ -21,7 +21,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
@@ -33,7 +33,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size,
@@ -47,14 +47,14 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
dx_grid_desc_m_k,
scale_grid_desc_m,
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
mean_var_grid_desc_m,
get_reduce_count_per_thread,
reduce_size,
@@ -77,7 +77,7 @@ template <typename XDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
@@ -93,8 +93,8 @@ template <typename XDataType,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t ScaleSrcVectorSize,
index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseBatchNormBackwardWithBlockwiseWelford
{
@@ -165,7 +165,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size,
@@ -179,8 +179,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
using ck::math::sqrt;
@@ -253,7 +253,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
@@ -271,7 +271,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
InMemoryDataOperationEnum::Set,
1,
true>(
dy_grid_desc_m_k,
dx_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
@@ -285,45 +285,27 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_dscale_store =
auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
DscaleDbiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
@@ -344,10 +326,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
p_scale, scale_grid_desc_m.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
// clang-format off
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
@@ -487,17 +469,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
if(thread_k_cluster_id == 0)
{
threadwise_dscale_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
scale_grid_desc_m,
dscale_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
dscale_dbias_grid_desc_m,
dscale_global_buf);
threadwise_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
bias_grid_desc_m,
dbias_global_buf);
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
dscale_dbias_grid_desc_m,
dbias_global_buf);
};
// clang-format off

View File

@@ -1,258 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseMultiblockWelfordFirstHalf_,
typename XDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
template <typename XDataType,
typename AccDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcCountSrcVectorDim,
index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
(XSrcCountSrcVectorDim == 1 &&
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcCountSrcVectorDim,
XSrcCountSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_welford_count_store =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_mean_thread_buf,
mean_var_count_grid_desc_m_g,
welford_mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_var_thread_buf,
mean_var_count_grid_desc_m_g,
welford_var_global_val_buf);
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_count_thread_buf,
mean_var_count_grid_desc_m_g,
welford_count_global_val_buf);
};
}
};
} // namespace ck