mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
BatchNorm backward implementation (#461)
* Implemented batchnorm-backward Blockwise and Multiblock kernels
* Add batchnorm-backward device op
* Add batchnorm-backward host-reference op
* Add batchnorm-backward example
* Parameters renaming in batchnorm backward kernels and device op
* Change in the example to loosen the threshold for ScaleDiff checking
* Add comments to explain the implementation of batchnorm-backward
* Parameters renaming again in batchnorm backward kernels
* Improve the expression calculation for performance
* Add batchnorm backward to README
* Add comments to explain inv-variance in batchnorm forward and backward
* Renaming the batchnorm forward training and inferring examples
* Add/update the comments for batchnorm-backward kernels
* Renaming again
* Add block_sync_lds between two consecutive blockwise reductions
* Move common expression 1/N out of the static_for loops
* Add dy_elementwise_op
* Renaming in backward example again
* Add checking for reduceDims in reference_batchnorm_backward
* Update to comments and codes format
* Rename in the comments
* Remove common expression out of the loop in reference_batchnorm_backward_nhwc_c
* Add block_sync_lds() between blockwise reduction again
* Fix comments again
* Remove int8 from batchnorm-forward instances since it is not needed for forward training and could fail test
@@ -0,0 +1,534 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
          typename XDataType,
          typename DyDataType,
          typename DxDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename DscaleDbiasGridDesc_M_K,
          typename MeanVarGridDesc_M,
          typename ScaleBiasGridDesc_M>
__global__ void kernel_reduce_second_half_batchnorm_backward_final(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K dy_grid_desc_m_k,
    const XYGridDesc_M_K dx_grid_desc_m_k,
    const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const ScaleBiasGridDesc_M scale_grid_desc_m,
    const ScaleBiasGridDesc_M bias_grid_desc_m,
    index_t blkgroup_size,
    long_index_t reduce_size,
    index_t num_xy_k_block_tile_iteration,
    index_t num_dscale_dbias_k_block_tile_iteration,
    const ScaleDataType* const __restrict__ p_reduce_dscale,
    const BiasDataType* const __restrict__ p_reduce_dbias,
    const MeanVarDataType* const __restrict__ p_mean,
    const MeanVarDataType* const __restrict__ p_inv_var,
    const XDataType* const __restrict__ p_x,
    const DyDataType* const __restrict__ p_dy,
    const ScaleDataType* const __restrict__ p_scale,
    const DyElementwiseOp dy_elementwise_op,
    DxDataType* const __restrict__ p_dx,
    ScaleDataType* const __restrict__ p_dscale,
    BiasDataType* const __restrict__ p_dbias)
{
    GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
                                                         dy_grid_desc_m_k,
                                                         dx_grid_desc_m_k,
                                                         dscale_dbias_grid_desc_m_k,
                                                         mean_var_grid_desc_m,
                                                         scale_grid_desc_m,
                                                         bias_grid_desc_m,
                                                         blkgroup_size,
                                                         reduce_size,
                                                         num_xy_k_block_tile_iteration,
                                                         num_dscale_dbias_k_block_tile_iteration,
                                                         p_reduce_dscale,
                                                         p_reduce_dbias,
                                                         p_mean,
                                                         p_inv_var,
                                                         p_x,
                                                         p_dy,
                                                         p_scale,
                                                         dy_elementwise_op,
                                                         p_dx,
                                                         p_dscale,
                                                         p_dbias);
};

template <typename XDataType,
          typename DyDataType,
          typename DxDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename DscaleDbiasGridDesc_M_K,
          typename MeanVarGridDesc_M,
          typename ScaleBiasGridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XDyDxVectorDim,
          index_t XSrcVectorSize,
          index_t DySrcVectorSize,
          index_t DxDstVectorSize,
          index_t ScaleSrcDstVectorSize,
          index_t BiasDstVectorSize,
          index_t MeanVarSrcVectorSize>
struct GridwiseReduceSecondHalfBatchNormBackwardFinal
{
    static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
                   MThreadSliceSize % DySrcVectorSize == 0 &&
                   MThreadSliceSize % DxDstVectorSize == 0) ||
                      (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
                       KThreadSliceSize % DySrcVectorSize == 0 &&
                       KThreadSliceSize % DxDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_1 = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
                                                          BlockSize,
                                                          ThreadClusterLengths_M_K,
                                                          ThreadClusterArrangeOrder,
                                                          ck::reduce::Add,
                                                          false>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                 ThreadReduceSrcDesc_M_1,
                                                 ThreadReduceDstDesc_M,
                                                 ck::reduce::Add,
                                                 false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    // clang-format off
    // Two of the steps of Multiblock BatchNorm Backward
    // Step 1: Second half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
    // Step 2: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x-mean) * inv-variance) elementwise
    // clang-format on
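    // Derivation sketch (standard batchnorm backward, stated here for reference):
    // with norm_x = (x - mean) * inv-variance and y = scale * norm_x + bias, the
    // chain rule through mean and variance gives, for N = reduce_size,
    //     dbias  = sum(dy)
    //     dscale = sum(dy * norm_x)
    //     dx     = (scale * inv-variance / N) * (N * dy - dbias - dscale * norm_x)
    // which is exactly the expression evaluated in Step 2 below.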
    __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
                               const XYGridDesc_M_K& dy_grid_desc_m_k,
                               const XYGridDesc_M_K& dx_grid_desc_m_k,
                               const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
                               const MeanVarGridDesc_M& mean_var_grid_desc_m,
                               const ScaleBiasGridDesc_M& scale_grid_desc_m,
                               const ScaleBiasGridDesc_M& bias_grid_desc_m,
                               index_t blkgroup_size,
                               long_index_t reduce_size,
                               index_t num_xy_k_block_tile_iteration,
                               index_t num_dscale_dbias_k_block_tile_iteration,
                               const ScaleDataType* const __restrict__ p_reduce_dscale,
                               const BiasDataType* const __restrict__ p_reduce_dbias,
                               const MeanVarDataType* const __restrict__ p_mean,
                               const MeanVarDataType* const __restrict__ p_inv_var,
                               const XDataType* const __restrict__ p_x,
                               const DyDataType* const __restrict__ p_dy,
                               const ScaleDataType* const __restrict__ p_scale,
                               const DyElementwiseOp dy_elementwise_op,
                               DxDataType* const __restrict__ p_dx,
                               ScaleDataType* const __restrict__ p_dscale,
                               BiasDataType* const __restrict__ p_dbias)
    {
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
            reduce_dscale_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
            reduce_dbias_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            dy_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            dx_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            inv_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / blkgroup_size;
        const index_t block_local_id  = block_global_id % blkgroup_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;
        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

        // clang-format off
        // Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
        // clang-format on

        auto threadwise_dscale_load_m_k =
            ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
                                             AccDataType,
                                             DscaleDbiasGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             true>(
                dscale_dbias_grid_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * 1));

        auto threadwise_dbias_load_m_k =
            ThreadwiseTensorSliceTransfer_v2<BiasDataType,
                                             AccDataType,
                                             DscaleDbiasGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             true>(
                dscale_dbias_grid_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * 1));

        auto threadwise_dscale_store_m =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               ScaleDataType,
                                               decltype(thread_buffer_desc_m),
                                               ScaleBiasGridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,
                                               0,
                                               ScaleSrcDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                scale_grid_desc_m,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        auto threadwise_dbias_store_m =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               BiasDataType,
                                               decltype(thread_buffer_desc_m),
                                               ScaleBiasGridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,
                                               0,
                                               BiasDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                bias_grid_desc_m,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        const auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_dscale, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());

        const auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());

        auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dscale, scale_grid_desc_m.GetElementSpaceSize());

        auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dbias, bias_grid_desc_m.GetElementSpaceSize());

        constexpr auto dscale_dbias_thread_copy_step_m_k =
            make_multi_index(0, KThreadClusterSize * 1);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            dscale_thread_buf(I) = type_convert<AccDataType>(0.0f);
            dbias_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
            ++reducedTiles)
        {
            threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
                                           reduce_dscale_global_buf,
                                           thread_buffer_desc_m_1,
                                           make_tuple(I0, I0),
                                           reduce_dscale_thread_buf);

            threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
                                          reduce_dbias_global_buf,
                                          thread_buffer_desc_m_1,
                                          make_tuple(I0, I0),
                                          reduce_dbias_thread_buf);

            ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
            ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);

            threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
                                                          dscale_dbias_thread_copy_step_m_k);
            threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
                                                         dscale_dbias_thread_copy_step_m_k);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
            block_sync_lds();
            BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
        });

        threadwise_dscale_store_m.Run(thread_buffer_desc_m,
                                      make_tuple(I0),
                                      dscale_thread_buf,
                                      scale_grid_desc_m,
                                      dscale_global_buf);

        threadwise_dbias_store_m.Run(thread_buffer_desc_m,
                                     make_tuple(I0),
                                     dbias_thread_buf,
                                     bias_grid_desc_m,
                                     dbias_global_buf);

        // clang-format off
        // Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
        // clang-format on

        const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  XYGridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XDyDxVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             workSizePerBlock * block_local_id +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
                                                                   AccDataType,
                                                                   XYGridDesc_M_K,
                                                                   decltype(thread_buffer_desc_m_k),
                                                                   ThreadBufferLengths_M_K,
                                                                   ThreadBufferDimAccessOrder,
                                                                   XDyDxVectorDim,
                                                                   DySrcVectorSize,
                                                                   1,
                                                                   true>(
            dy_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             workSizePerBlock * block_local_id +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dx_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               DxDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               XYGridDesc_M_K,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               XDyDxVectorDim,
                                               DxDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                dx_grid_desc_m_k,
                make_multi_index(
                    blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize),
                PassThroughOp{});

        auto threadwise_scale_load =
            ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
                                             AccDataType,
                                             ScaleBiasGridDesc_M,
                                             decltype(thread_buffer_desc_m),
                                             ThreadBufferLengths_M,
                                             Sequence<0>,
                                             0,
                                             ScaleSrcDstVectorSize,
                                             1,
                                             true>(
                scale_grid_desc_m,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize));

        auto threadwise_mean_var_load =
            ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                             AccDataType,
                                             MeanVarGridDesc_M,
                                             decltype(thread_buffer_desc_m),
                                             ThreadBufferLengths_M,
                                             Sequence<0>,
                                             0,
                                             MeanVarSrcVectorSize,
                                             1,
                                             true>(
                mean_var_grid_desc_m,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize));

        const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy, dy_grid_desc_m_k.GetElementSpaceSize());

        auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dx, dx_grid_desc_m_k.GetElementSpaceSize());

        const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_scale, scale_grid_desc_m.GetElementSpaceSize());

        const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean, mean_var_grid_desc_m.GetElementSpaceSize());

        const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize());

        threadwise_scale_load.Run(scale_grid_desc_m,
                                  scale_global_buf,
                                  thread_buffer_desc_m,
                                  make_tuple(I0),
                                  scale_thread_buf);

        threadwise_mean_var_load.Run(mean_var_grid_desc_m,
                                     mean_global_buf,
                                     thread_buffer_desc_m,
                                     make_tuple(I0),
                                     mean_thread_buf);

        threadwise_mean_var_load.Run(mean_var_grid_desc_m,
                                     inv_var_global_buf,
                                     thread_buffer_desc_m,
                                     make_tuple(I0),
                                     inv_var_thread_buf);

        constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);

        AccDataType inv_reduce_size =
            type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size);

        for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_dy_load.Run(dy_grid_desc_m_k,
                                   dy_global_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   dy_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                AccDataType multiplier =
                    inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM];

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    dy_elementwise_op(dy_thread_buf(Number<offset>{}),
                                      dy_thread_buf[Number<offset>{}]);

                    AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                         inv_var_thread_buf[iM];

                    AccDataType tmpVal = norm_x * dscale_thread_buf[iM];

                    dx_thread_buf(Number<offset>{}) =
                        multiplier *
                        (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
                         dbias_thread_buf[iM] - tmpVal);
                });
            });

            threadwise_dx_store.Run(thread_buffer_desc_m_k,
                                    make_tuple(I0, I0),
                                    dx_thread_buf,
                                    dx_grid_desc_m_k,
                                    dx_global_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
            threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, xy_thread_copy_step_m_k);
        }
    };
};

} // namespace ck
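For cross-checking the kernel math against a simple serial version, the following is a minimal host-side sketch of what Steps 1 and 2 compute for one normalized channel; the function name batchnorm_bwd_reference and the flat std::vector layout are illustrative, not part of composable_kernel.

#include <cmath>
#include <cstddef>
#include <vector>

// Reference batchnorm backward for one normalized channel: given x, dy, the
// channel's mean/inv-variance and scale, produce dbias, dscale and dx.
void batchnorm_bwd_reference(const std::vector<double>& x,
                             const std::vector<double>& dy,
                             double mean,
                             double inv_var, // 1/sqrt(epsilon + variance)
                             double scale,
                             double& dbias,
                             double& dscale,
                             std::vector<double>& dx)
{
    const std::size_t N = x.size();

    // Step 1: the two reductions the kernel performs threadwise/blockwise.
    dbias  = 0.0;
    dscale = 0.0;
    for(std::size_t i = 0; i < N; ++i)
    {
        const double norm_x = (x[i] - mean) * inv_var;
        dbias += dy[i];
        dscale += dy[i] * norm_x;
    }

    // Step 2: the elementwise dx formula; the factor (inv_var * scale / N) is
    // hoisted out of the loop, mirroring the "multiplier" in the kernel.
    dx.resize(N);
    const double multiplier = inv_var * scale / static_cast<double>(N);
    for(std::size_t i = 0; i < N; ++i)
    {
        const double norm_x = (x[i] - mean) * inv_var;
        dx[i] = multiplier * (static_cast<double>(N) * dy[i] - dbias - dscale * norm_x);
    }
}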
@@ -93,6 +93,9 @@ struct GridwiseMultiblockWelfordFirstHalf
    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    // clang-format off
    // First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward.
    // clang-format on
    __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
                               const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
                               const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,

@@ -529,6 +529,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
        auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());

        // calculate inv-variance as 1/sqrt(epsilon+variance)
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            welford_var_thread_buf(I) =
                type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
@@ -0,0 +1,575 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
          typename XDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename MeanVarGridDesc_M,
          typename MeanVarCountGridDesc_M_K,
          typename DscaleDbiasGridDesc_M_G>
__global__ void kernel_welford_second_half_reduce_first_half(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K dy_grid_desc_m_k,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
    const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g,
    index_t blkgroup_size,
    index_t num_xy_k_block_tile_iteration,
    index_t num_mean_var_count_k_block_tile_iteration,
    AccDataType epsilon,
    bool haveSavedMeanInvVar,
    const MeanVarDataType* const __restrict__ p_savedMean,
    const MeanVarDataType* const __restrict__ p_savedInvVar,
    const MeanVarDataType* const __restrict__ p_in_welford_mean,
    const MeanVarDataType* const __restrict__ p_in_welford_variance,
    const int32_t* const __restrict__ p_in_welford_count,
    const DyElementwiseOp dy_elementwise_op,
    MeanVarDataType* const __restrict__ p_out_welford_mean,
    MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
    const XDataType* const __restrict__ p_x,
    const DyDataType* const __restrict__ p_dy,
    ScaleDataType* const __restrict__ p_reduce_dscale,
    BiasDataType* const __restrict__ p_reduce_dbias)
{
    GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
                                                   dy_grid_desc_m_k,
                                                   mean_var_grid_desc_m,
                                                   mean_var_count_grid_desc_m_k,
                                                   dscale_dbias_grid_desc_m_g,
                                                   blkgroup_size,
                                                   num_xy_k_block_tile_iteration,
                                                   num_mean_var_count_k_block_tile_iteration,
                                                   epsilon,
                                                   haveSavedMeanInvVar,
                                                   p_savedMean,
                                                   p_savedInvVar,
                                                   p_in_welford_mean,
                                                   p_in_welford_variance,
                                                   p_in_welford_count,
                                                   dy_elementwise_op,
                                                   p_out_welford_mean,
                                                   p_out_welford_inv_variance,
                                                   p_x,
                                                   p_dy,
                                                   p_reduce_dscale,
                                                   p_reduce_dbias);
};

template <typename XDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename MeanVarGridDesc_M,
          typename MeanVarCountGridDesc_M_K,
          typename DscaleDbiasGridDesc_M_G,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XDyVectorDim,
          index_t XSrcVectorSize,
          index_t DySrcVectorSize,
          index_t MeanVarSrcVectorSize>
struct GridwiseWelfordSecondHalfReduceFirstHalf
{
    static_assert((XDyVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
                   MThreadSliceSize % DySrcVectorSize == 0) ||
                      (XDyVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
                       KThreadSliceSize % DySrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XDyVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceSrcDesc_M_1 = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
                                                          BlockSize,
                                                          ThreadClusterLengths_M_K,
                                                          ThreadClusterArrangeOrder,
                                                          ck::reduce::Add,
                                                          false>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                 ThreadReduceSrcDesc_M_K,
                                                 ThreadReduceDstDesc_M,
                                                 ck::reduce::Add,
                                                 false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    // clang-format off
    // Two of the steps of Multiblock BatchNorm Backward
    // Step 1: Second half of Welford method to calculate mean and variance, as well as getting inv-variance = 1/sqrt(epsilon+variance)
    // Step 2: First half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
    // clang-format on
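    // For reference, the merge rule applied by ThreadwiseWelfordMerge/BlockwiseWelford on two
    // partial results (mean_a, var_a, n_a) and (mean_b, var_b, n_b) is the standard parallel
    // Welford update:
    //     n    = n_a + n_b
    //     mean = mean_a + (mean_b - mean_a) * n_b / n
    //     M2   = M2_a + M2_b + (mean_b - mean_a)^2 * n_a * n_b / n,  with M2 = var * n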
    __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
                               const XYGridDesc_M_K& dy_grid_desc_m_k,
                               const MeanVarGridDesc_M& mean_var_grid_desc_m,
                               const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
                               const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g,
                               index_t blkgroup_size,
                               index_t num_xy_k_block_tile_iteration,
                               index_t num_mean_var_count_k_block_tile_iteration,
                               AccDataType epsilon,
                               bool haveSavedMeanInvVar,
                               const MeanVarDataType* const __restrict__ p_savedMean,
                               const MeanVarDataType* const __restrict__ p_savedInvVar,
                               const MeanVarDataType* const __restrict__ p_in_welford_mean,
                               const MeanVarDataType* const __restrict__ p_in_welford_variance,
                               const int32_t* const __restrict__ p_in_welford_count,
                               const DyElementwiseOp dy_elementwise_op,
                               MeanVarDataType* const __restrict__ p_out_welford_mean,
                               MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
                               const XDataType* const __restrict__ p_x,
                               const DyDataType* const __restrict__ p_dy,
                               ScaleDataType* const __restrict__ p_reduce_dscale,
                               BiasDataType* const __restrict__ p_reduce_dbias)
    {
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
            in_welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
            in_welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize * 1, true>
            in_welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& mean_thread_buf =
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
            inv_var_thread_buf = welford_var_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            dy_thread_buf;

        // buffer of values of dy * (x-mean) * inv-variance, used as input of Blockwise reduction
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            tmp1_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            reduce_dscale_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            reduce_dbias_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / blkgroup_size;
        const index_t block_local_id  = block_global_id % blkgroup_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;
        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

        // clang-format off
        // Step 1: load existing mean and inv-variance, or do final welford reduction on mean and variance as well as get inv-variance = 1/sqrt(epsilon+variance)
        // clang-format on

        if(haveSavedMeanInvVar)
        {
            const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());

            const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());

            auto threadwise_mean_inv_var_load =
                ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                                 AccDataType,
                                                 MeanVarGridDesc_M,
                                                 decltype(thread_buffer_desc_m),
                                                 ThreadBufferLengths_M,
                                                 Sequence<0>,
                                                 0,
                                                 MeanVarSrcVectorSize,
                                                 1,
                                                 true>(
                    mean_var_grid_desc_m,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize));

            threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
                                             mean_global_buf,
                                             thread_buffer_desc_m,
                                             make_tuple(I0),
                                             mean_thread_buf);

            threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
                                             inv_var_global_buf,
                                             thread_buffer_desc_m,
                                             make_tuple(I0),
                                             inv_var_thread_buf);
        }
        else
        {
            const auto welford_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            const auto welford_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            const auto welford_count_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            auto threadwise_mean_var_load_m_k =
                ThreadwiseTensorSliceTransfer_v2<AccDataType,
                                                 AccDataType,
                                                 MeanVarCountGridDesc_M_K,
                                                 decltype(thread_buffer_desc_m_1),
                                                 ThreadBufferLengths_M_1,
                                                 Sequence<0, 1>,
                                                 1,
                                                 1,
                                                 1,
                                                 true>(
                    mean_var_count_grid_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * 1));

            auto threadwise_count_load_m_k =
                ThreadwiseTensorSliceTransfer_v2<int32_t,
                                                 int32_t,
                                                 MeanVarCountGridDesc_M_K,
                                                 decltype(thread_buffer_desc_m_1),
                                                 ThreadBufferLengths_M_1,
                                                 Sequence<0, 1>,
                                                 1,
                                                 1,
                                                 1,
                                                 true>(
                    mean_var_count_grid_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * 1));

            constexpr auto mean_var_count_thread_copy_step_m_k =
                make_multi_index(0, KThreadClusterSize * 1);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                welford_mean_thread_buf(I)  = type_convert<AccDataType>(0.0f);
                welford_var_thread_buf(I)   = type_convert<AccDataType>(0.0f);
                welford_count_thread_buf(I) = 0;
            });

            for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
                ++reducedTiles)
            {
                threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                                 welford_mean_global_buf,
                                                 thread_buffer_desc_m_1,
                                                 make_tuple(I0, I0),
                                                 in_welford_mean_thread_buf);

                threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                                 welford_var_global_buf,
                                                 thread_buffer_desc_m_1,
                                                 make_tuple(I0, I0),
                                                 in_welford_var_thread_buf);

                threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                              welford_count_global_buf,
                                              thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              in_welford_count_thread_buf);

                ThreadwiseWelford::Run(in_welford_mean_thread_buf,
                                       in_welford_var_thread_buf,
                                       in_welford_count_thread_buf,
                                       welford_mean_thread_buf,
                                       welford_var_thread_buf,
                                       welford_count_thread_buf);

                threadwise_mean_var_load_m_k.MoveSrcSliceWindow(
                    mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k);
                threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
                                                             mean_var_count_thread_copy_step_m_k);
            }

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if constexpr(I > 0)
                    block_sync_lds();

                BlockwiseWelford::Run(welford_mean_thread_buf(I),
                                      welford_var_thread_buf(I),
                                      welford_count_thread_buf(I));
            });

            // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                welford_var_thread_buf(I) =
                    type_convert<AccDataType>(1.0) / sqrt(welford_var_thread_buf[I] + epsilon);
            });

            if(block_local_id == 0 && thread_k_cluster_id == 0)
            {
                auto threadwise_mean_inv_var_store =
                    ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                       MeanVarDataType,
                                                       decltype(thread_buffer_desc_m),
                                                       MeanVarGridDesc_M,
                                                       PassThroughOp,
                                                       ThreadBufferLengths_M,
                                                       Sequence<0>,
                                                       0,
                                                       1,
                                                       InMemoryDataOperationEnum::Set,
                                                       1,
                                                       true>(
                        mean_var_grid_desc_m,
                        make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize),
                        PassThroughOp{});

                auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());

                auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());

                threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                                  make_tuple(I0),
                                                  mean_thread_buf,
                                                  mean_var_grid_desc_m,
                                                  mean_global_buf);

                threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                                  make_tuple(I0),
                                                  inv_var_thread_buf,
                                                  mean_var_grid_desc_m,
                                                  inv_var_global_buf);
            };
        };

        const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  XYGridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XDyVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             workSizePerBlock * block_local_id +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
                                                                   AccDataType,
                                                                   XYGridDesc_M_K,
                                                                   decltype(thread_buffer_desc_m_k),
                                                                   ThreadBufferLengths_M_K,
                                                                   ThreadBufferDimAccessOrder,
                                                                   XDyVectorDim,
                                                                   DySrcVectorSize,
                                                                   1,
                                                                   true>(
            dy_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             workSizePerBlock * block_local_id +
                                 thread_k_cluster_id * KThreadSliceSize));

        const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy, dy_grid_desc_m_k.GetElementSpaceSize());

        constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            reduce_dscale_thread_buf(I) = type_convert<AccDataType>(0);
            reduce_dbias_thread_buf(I)  = type_convert<AccDataType>(0);
        });

        // clang-format off
        // Step 2: first-half of reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
        // clang-format on

        for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_dy_load.Run(dy_grid_desc_m_k,
                                   dy_global_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   dy_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    dy_elementwise_op(dy_thread_buf(Number<offset>{}),
                                      dy_thread_buf[Number<offset>{}]);

                    AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                         inv_var_thread_buf[iM];

                    tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
                });
            });

            ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
            ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
        };

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I));
            block_sync_lds();
            BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
        });

        auto threadwise_dscale_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               ScaleDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               DscaleDbiasGridDesc_M_G,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                dscale_dbias_grid_desc_m_g,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
                PassThroughOp{});

        auto threadwise_dbias_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               BiasDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               DscaleDbiasGridDesc_M_G,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                dscale_dbias_grid_desc_m_g,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
                PassThroughOp{});

        auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());

        auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());

        if(thread_k_cluster_id == 0)
        {
            threadwise_dscale_store.Run(thread_buffer_desc_m_1,
                                        make_tuple(I0, I0),
                                        reduce_dscale_thread_buf,
                                        dscale_dbias_grid_desc_m_g,
                                        reduce_dscale_global_buf);

            threadwise_dbias_store.Run(thread_buffer_desc_m_1,
                                       make_tuple(I0, I0),
                                       reduce_dbias_thread_buf,
                                       dscale_dbias_grid_desc_m_g,
                                       reduce_dbias_global_buf);
        };
    };
};

} // namespace ck
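The second-half kernel above begins by merging the per-block partial (mean, variance, count) results written by the Welford first half. A minimal host-side sketch of that merge rule follows; the WelfordPartial struct and welford_merge function are illustrative, not CK types.

#include <cstdint>

struct WelfordPartial
{
    double mean;
    double variance; // biased variance, i.e. M2 / count
    int32_t count;
};

// Merge two Welford partials (Chan et al. parallel update), as the
// ThreadwiseWelfordMerge/BlockwiseWelford stages do per normalized channel.
WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
{
    const int32_t n = a.count + b.count;
    if(n == 0)
        return {0.0, 0.0, 0};

    const double delta = b.mean - a.mean;
    const double mean  = a.mean + delta * b.count / n;

    // Recover the sums of squared deviations (M2), combine, then renormalize.
    const double m2 =
        a.variance * a.count + b.variance * b.count + delta * delta * a.count * b.count / n;

    return {mean, m2 / n, n};
}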
@@ -0,0 +1,572 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseBatchNormBackwardWithBlockwiseWelford_,
          typename XDataType,
          typename DyDataType,
          typename DxDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename ScaleBiasGridDesc_M,
          typename MeanVarGridDesc_M,
          typename GetReduceCountPerThreadFunctor>
__global__ void kernel_batchnorm_backward_with_blockwise_welford(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K dy_grid_desc_m_k,
    const XYGridDesc_M_K dx_grid_desc_m_k,
    const ScaleBiasGridDesc_M scale_grid_desc_m,
    const ScaleBiasGridDesc_M bias_grid_desc_m,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
    long_index_t reduce_size,
    index_t num_k_block_tile_iteration,
    AccDataType epsilon,
    const XDataType* const __restrict__ p_x,
    const DyDataType* const __restrict__ p_dy,
    const ScaleDataType* const __restrict__ p_scale,
    bool haveSavedMeanInvVar,
    const MeanVarDataType* const __restrict__ p_savedMean,
    const MeanVarDataType* const __restrict__ p_savedInvVar,
    const DyElementwiseOp dy_elementwise_op,
    DxDataType* const __restrict__ p_dx,
    ScaleDataType* const __restrict__ p_dscale,
    BiasDataType* const __restrict__ p_dbias)
{
    GridwiseBatchNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
                                                        dy_grid_desc_m_k,
                                                        dx_grid_desc_m_k,
                                                        scale_grid_desc_m,
                                                        bias_grid_desc_m,
                                                        mean_var_grid_desc_m,
                                                        get_reduce_count_per_thread,
                                                        reduce_size,
                                                        num_k_block_tile_iteration,
                                                        epsilon,
                                                        p_x,
                                                        p_dy,
                                                        p_scale,
                                                        haveSavedMeanInvVar,
                                                        p_savedMean,
                                                        p_savedInvVar,
                                                        dy_elementwise_op,
                                                        p_dx,
                                                        p_dscale,
                                                        p_dbias);
};

template <typename XDataType,
          typename DyDataType,
          typename DxDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename ScaleBiasGridDesc_M,
          typename MeanVarGridDesc_M,
          typename GetReduceCountPerThreadFunctor,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XDyDxVectorDim,
          index_t XSrcVectorSize,
          index_t DySrcVectorSize,
          index_t DxDstVectorSize,
          index_t ScaleSrcDstVectorSize,
          index_t BiasDstVectorSize,
          index_t MeanVarSrcVectorSize>
struct GridwiseBatchNormBackwardWithBlockwiseWelford
{
    static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
                   MThreadSliceSize % DySrcVectorSize == 0 &&
                   MThreadSliceSize % DxDstVectorSize == 0) ||
                      (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
                       KThreadSliceSize % DySrcVectorSize == 0 &&
                       KThreadSliceSize % DxDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
                                                          BlockSize,
                                                          ThreadClusterLengths_M_K,
                                                          ThreadClusterArrangeOrder,
                                                          ck::reduce::Add,
                                                          false>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                 ThreadReduceSrcDesc_M_K,
                                                 ThreadReduceDstDesc_M,
                                                 ck::reduce::Add,
                                                 false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    // clang-format off
    // Blockwise BatchNorm Backward
    // Input: x, dy, scale, savedMean and savedInvVar (optional), reduce_size
    // Output: dx, dscale, dbias
    // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
    // Step 2: reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
    // Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x-mean) * inv-variance) elementwise
    // clang-format on
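    // Note: this blockwise variant performs in one kernel the same three steps that the
    // multiblock path splits between GridwiseWelfordSecondHalfReduceFirstHalf and
    // GridwiseReduceSecondHalfBatchNormBackwardFinal (plus the Welford first half):
    // each workgroup owns an entire reduction slice, so no partial-result buffers in
    // global memory are needed.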
|
||||
__device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k,
|
||||
const XYGridDesc_M_K dy_grid_desc_m_k,
|
||||
const XYGridDesc_M_K dx_grid_desc_m_k,
|
||||
const ScaleBiasGridDesc_M scale_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M bias_grid_desc_m,
|
||||
const MeanVarGridDesc_M mean_var_grid_desc_m,
|
||||
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
|
||||
long_index_t reduce_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
const XDataType* const __restrict__ p_x,
|
||||
const DyDataType* const __restrict__ p_dy,
|
||||
const ScaleDataType* const __restrict__ p_scale,
|
||||
bool haveSavedMeanInvVar,
|
||||
const MeanVarDataType* const __restrict__ p_savedMean,
|
||||
const MeanVarDataType* const __restrict__ p_savedInvVar,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
DxDataType* const __restrict__ p_dx,
|
||||
ScaleDataType* const __restrict__ p_dscale,
|
||||
BiasDataType* const __restrict__ p_dbias)
|
||||
{
|
||||
using ck::math::sqrt;
|
||||
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
auto reduce_work_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
x_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
dy_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
dx_thread_buf;
|
||||
|
||||
// buffer of values of dy * (x-mean) * invVariance, used as input of Blockwise reduction
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
tmp1_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
|
||||
inv_var_thread_buf = var_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
constexpr auto thread_buffer_desc_m =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
|
||||
AccDataType,
|
||||
XYGridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
XDyDxVectorDim,
|
||||
XSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
x_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
    auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
                                                               AccDataType,
                                                               XYGridDesc_M_K,
                                                               decltype(thread_buffer_desc_m_k),
                                                               ThreadBufferLengths_M_K,
                                                               ThreadBufferDimAccessOrder,
                                                               XDyDxVectorDim,
                                                               XSrcVectorSize,
                                                               1,
                                                               true>(
        dy_grid_desc_m_k,
        make_multi_index(block_global_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize,
                         thread_k_cluster_id * KThreadSliceSize));

    auto threadwise_dx_store =
        ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                           DxDataType,
                                           decltype(thread_buffer_desc_m_k),
                                           XYGridDesc_M_K,
                                           PassThroughOp,
                                           ThreadBufferLengths_M_K,
                                           ThreadBufferDimAccessOrder,
                                           XDyDxVectorDim,
                                           DxDstVectorSize,
                                           InMemoryDataOperationEnum::Set,
                                           1,
                                           true>(
            dx_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize),
            PassThroughOp{});
    auto threadwise_scale_load =
        ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
                                         AccDataType,
                                         ScaleBiasGridDesc_M,
                                         decltype(thread_buffer_desc_m),
                                         ThreadBufferLengths_M,
                                         Sequence<0>,
                                         0,
                                         ScaleSrcDstVectorSize,
                                         1,
                                         true>(
            scale_grid_desc_m,
            make_multi_index(block_global_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize));

    auto threadwise_dscale_store =
        ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                           ScaleDataType,
                                           decltype(thread_buffer_desc_m),
                                           ScaleBiasGridDesc_M,
                                           PassThroughOp,
                                           ThreadBufferLengths_M,
                                           Sequence<0>,
                                           0,
                                           ScaleSrcDstVectorSize,
                                           InMemoryDataOperationEnum::Set,
                                           1,
                                           true>(
            scale_grid_desc_m,
            make_multi_index(block_global_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize),
            PassThroughOp{});

    auto threadwise_dbias_store =
        ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                           BiasDataType,
                                           decltype(thread_buffer_desc_m),
                                           ScaleBiasGridDesc_M,
                                           PassThroughOp,
                                           ThreadBufferLengths_M,
                                           Sequence<0>,
                                           0,
                                           BiasDstVectorSize,
                                           InMemoryDataOperationEnum::Set,
                                           1,
                                           true>(
            bias_grid_desc_m,
            make_multi_index(block_global_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize),
            PassThroughOp{});

    constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
    constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
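    // The forward step slides each thread's window one K block-tile ahead per loop
    // iteration; the backward step lets Step 3 sweep the same tiles in reverse from
    // wherever the Step 1/2 loops left off, without resetting the slice window to
    // the origin.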
    const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_x, x_grid_desc_m_k.GetElementSpaceSize());

    const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_dy, dy_grid_desc_m_k.GetElementSpaceSize());

    auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_dx, dx_grid_desc_m_k.GetElementSpaceSize());

    const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_scale, scale_grid_desc_m.GetElementSpaceSize());

    auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_dscale, scale_grid_desc_m.GetElementSpaceSize());

    auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_dbias, bias_grid_desc_m.GetElementSpaceSize());

    // clang-format off
    // Step 1: calculate mean and inv-variance using the Welford method (if savedMean/savedInvVar are not available), where inv-variance = 1/sqrt(epsilon+variance)
    // clang-format on
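    // For reference, the threadwise pass below is assumed to apply the standard
    // Welford recurrences per new element x_n:
    //     count <- count + 1
    //     delta  = x_n - mean
    //     mean  <- mean + delta / count
    //     var   <- var + (delta * (x_n - mean) - var) / count
    // and BlockwiseWelford then merges the per-thread (mean, var, count) partials.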
    if(haveSavedMeanInvVar)
    {
        const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());

        const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());

        auto threadwise_mean_inv_var_load =
            ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                             AccDataType,
                                             MeanVarGridDesc_M,
                                             decltype(thread_buffer_desc_m),
                                             ThreadBufferLengths_M,
                                             Sequence<0>,
                                             0,
                                             MeanVarSrcVectorSize,
                                             1,
                                             true>(
                mean_var_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize));

        threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
                                         mean_global_buf,
                                         thread_buffer_desc_m,
                                         make_tuple(I0),
                                         mean_thread_buf);

        threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
                                         inv_var_global_buf,
                                         thread_buffer_desc_m,
                                         make_tuple(I0),
                                         inv_var_thread_buf);
    }
    else
    {
        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
        });
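        // The block_sync_lds() above separates consecutive BlockwiseWelford::Run
        // calls: they reuse the same LDS workspace, so iteration I must not start
        // overwriting it before every thread has finished reading the results of
        // iteration I-1.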
        // calculate inv-variance as 1/sqrt(epsilon+variance)
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            inv_var_thread_buf(I) =
                type_convert<AccDataType>(1.0) / sqrt(var_thread_buf[I] + epsilon);
        });

        threadwise_x_load.SetSrcSliceOrigin(
            x_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));
    };
    // clang-format off
    // Step 2: reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
    // clang-format on
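    // With the normalized input x_hat = (x - mean) * inv-variance, these are the
    // usual batchnorm parameter gradients, reduced over the invariant dimensions:
    //     dbias  = sum_k dy[k]
    //     dscale = sum_k dy[k] * x_hat[k]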
    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
        dscale_thread_buf(I) = type_convert<AccDataType>(0);
        dbias_thread_buf(I)  = type_convert<AccDataType>(0);
    });

    for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
    {
        threadwise_x_load.Run(x_grid_desc_m_k,
                              x_global_buf,
                              thread_buffer_desc_m_k,
                              make_tuple(I0, I0),
                              x_thread_buf);

        threadwise_dy_load.Run(dy_grid_desc_m_k,
                               dy_global_buf,
                               thread_buffer_desc_m_k,
                               make_tuple(I0, I0),
                               dy_thread_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                constexpr auto offset =
                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                dy_elementwise_op(dy_thread_buf(Number<offset>{}),
                                  dy_thread_buf[Number<offset>{}]);

                AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                     inv_var_thread_buf[iM];

                tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
            });
        });

        ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf);
        ThreadwiseReduce::Reduce(dy_thread_buf, dbias_thread_buf);

        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
        threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
    };
    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
        if constexpr(I > 0)
            block_sync_lds();
        BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
        block_sync_lds();
        BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
    });

    if(thread_k_cluster_id == 0)
    {
        threadwise_dscale_store.Run(thread_buffer_desc_m,
                                    make_tuple(I0),
                                    dscale_thread_buf,
                                    scale_grid_desc_m,
                                    dscale_global_buf);

        threadwise_dbias_store.Run(thread_buffer_desc_m,
                                   make_tuple(I0),
                                   dbias_thread_buf,
                                   bias_grid_desc_m,
                                   dbias_global_buf);
    };
    // clang-format off
    // Step 3: calculate dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance) element-wise
    // clang-format on
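    // This is the standard batchnorm input gradient, rearranged so the two
    // reductions from Step 2 can be reused. With x_hat = (x - mean) * inv-variance
    // and N = reduce_size:
    //     dx = (scale * inv-variance / N) * (N * dy - dbias - x_hat * dscale)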
    threadwise_scale_load.Run(scale_grid_desc_m,
                              scale_global_buf,
                              thread_buffer_desc_m,
                              make_tuple(I0),
                              scale_thread_buf);

    auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;

    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
    threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
    threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);

    AccDataType inv_reduce_size =
        type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size);
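    // 1/reduce_size is computed once here (and folded into `multiplier` below, once
    // per M slice), so the inner K loop performs only multiplies, never a
    // per-element division.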
    for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
    {
        threadwise_x_load.Run(x_grid_desc_m_k,
                              x_global_buf,
                              thread_buffer_desc_m_k,
                              make_tuple(I0, I0),
                              x_thread_buf);

        threadwise_dy_load.Run(dy_grid_desc_m_k,
                               dy_global_buf,
                               thread_buffer_desc_m_k,
                               make_tuple(I0, I0),
                               dy_thread_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
            AccDataType multiplier =
                inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM];

            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                constexpr auto offset =
                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                dy_elementwise_op(dy_thread_buf(Number<offset>{}),
                                  dy_thread_buf[Number<offset>{}]);

                AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                     inv_var_thread_buf[iM];

                AccDataType tmpVal = norm_x * dscale_thread_buf[iM];

                dx_thread_buf(Number<offset>{}) =
                    multiplier *
                    (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
                     dbias_thread_buf[iM] - tmpVal);
            });
        });

        threadwise_dx_store.Run(thread_buffer_desc_m_k,
                                make_tuple(I0, I0),
                                dx_thread_buf,
                                dx_grid_desc_m_k,
                                dx_global_buf);

        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
        threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
        threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
    }
}
};

} // namespace ck
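For intuition, here is a naive single-channel host-side version of Steps 1-3, a minimal sketch in the spirit of the reference_batchnorm_backward host op this commit mentions (naive_bn_bwd and its signature are illustrative, not part of the commit):

#include <cmath>
#include <cstddef>

// Batchnorm backward over one channel of N values, mirroring Steps 1-3 above.
void naive_bn_bwd(const float* x, const float* dy, float scale, float epsilon,
                  std::size_t N, float* dx, float& dscale, float& dbias)
{
    // Step 1: mean and inv-variance = 1/sqrt(epsilon + variance)
    float mean = 0.f, var = 0.f;
    for(std::size_t i = 0; i < N; ++i)
        mean += x[i];
    mean /= N;
    for(std::size_t i = 0; i < N; ++i)
        var += (x[i] - mean) * (x[i] - mean);
    var /= N;
    const float inv_var = 1.f / std::sqrt(epsilon + var);

    // Step 2: dbias = sum(dy), dscale = sum(dy * x_hat)
    dscale = 0.f;
    dbias  = 0.f;
    for(std::size_t i = 0; i < N; ++i)
    {
        const float norm_x = (x[i] - mean) * inv_var;
        dscale += dy[i] * norm_x;
        dbias += dy[i];
    }

    // Step 3: dx = scale * inv_var / N * (N * dy - dbias - x_hat * dscale)
    const float multiplier = scale * inv_var / N;
    for(std::size_t i = 0; i < N; ++i)
    {
        const float norm_x = (x[i] - mean) * inv_var;
        dx[i] = multiplier * (N * dy[i] - dbias - norm_x * dscale);
    }
}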
@@ -441,6 +441,7 @@ struct GridwiseBatchNormForwardWithBlockwiseWelford
        auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());

        // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            var_thread_buf(I) =
                type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
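        // Note: the forward training path saves inv-variance = 1/sqrt(epsilon + variance)
        // rather than the raw variance, so the backward kernels above can consume
        // savedInvVar directly instead of recomputing a sqrt and a division per channel.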
@@ -0,0 +1,258 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseMultiblockWelfordFirstHalf_,
          typename XDataType,
          typename MeanVarDataType,
          typename XGridDesc_M_K,
          typename MeanVarCountGridDesc_M_G,
          typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
    const XGridDesc_M_K x_grid_desc_m_k,
    const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
    const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
    index_t num_k_block_tile_iteration,
    const XDataType* const __restrict__ p_x,
    MeanVarDataType* const p_welford_mean,
    MeanVarDataType* const p_welford_variance,
    int32_t* const p_welford_count)
{
    GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
                                             mean_var_count_grid_desc_m_g,
                                             get_reduce_count_per_thread,
                                             num_k_block_tile_iteration,
                                             p_x,
                                             p_welford_mean,
                                             p_welford_variance,
                                             p_welford_count);
};
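// This is the first kernel of the multiblock (two-pass) scheme: when the reduction
// length is too large for one workgroup, each of the G blocks in a block-group
// reduces its own K slab and writes a partial (mean, variance, count) triple into
// an M x G workspace; a second-half kernel is then assumed to merge the G partials
// per M slice before the final batchnorm computation.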
template <typename XDataType,
          typename AccDataType,
          typename MeanVarDataType,
          typename XGridDesc_M_K,
          typename MeanVarCountGridDesc_M_G,
          typename GetReduceCountPerThreadFunctor,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcCountSrcVectorDim,
          index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
    static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
                      (XSrcCountSrcVectorDim == 1 &&
                       KThreadSliceSize % XSrcCountSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder,
                                              false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
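    // Illustrative sizing (hypothetical numbers, not mandated by this commit): with
    // BlockSize = 256, MThreadClusterSize = 8, KThreadClusterSize = 32,
    // MThreadSliceSize = 1 and KThreadSliceSize = 8, one workgroup covers an
    // M_BlockTileSize x K_BlockTileSize = 8 x 256 tile per loop iteration.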
    __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
                               const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
                               const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
                               index_t num_k_block_tile_iteration,
                               const XDataType* const __restrict__ p_x,
                               MeanVarDataType* const p_welford_mean,
                               MeanVarDataType* const p_welford_variance,
                               int32_t* const p_welford_count)
    {
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / blkgroup_size;
        const index_t block_local_id  = block_global_id % blkgroup_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;

        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
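        // Grid decomposition: blkgroup_id picks the M block-tile, while block_local_id
        // picks which contiguous K slab of length reduceSizePerBlock this workgroup
        // reduces; its partial result lands at column block_local_id of the M x G
        // mean/var/count workspace.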
        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  XGridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XSrcCountSrcVectorDim,
                                                                  XSrcCountSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_welford_mean_var_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               MeanVarDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               MeanVarCountGridDesc_M_G,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                mean_var_count_grid_desc_m_g,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
                PassThroughOp{});

        auto threadwise_welford_count_store =
            ThreadwiseTensorSliceTransfer_v1r3<int32_t,
                                               int32_t,
                                               decltype(thread_buffer_desc_m_1),
                                               MeanVarCountGridDesc_M_G,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                mean_var_count_grid_desc_m_g,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
                PassThroughOp{});
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());

        auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());

        auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());

        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ =
            get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            welford_var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            welford_count_thread_buf(I) = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(
                welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
        });

        if(thread_k_cluster_id == 0)
        {
            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  welford_mean_thread_buf,
                                                  mean_var_count_grid_desc_m_g,
                                                  welford_mean_global_val_buf);

            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  welford_var_thread_buf,
                                                  mean_var_count_grid_desc_m_g,
                                                  welford_var_global_val_buf);

            threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               welford_count_thread_buf,
                                               mean_var_count_grid_desc_m_g,
                                               welford_count_global_val_buf);
        };
    }
};

} // namespace ck
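For context, the second-half kernel is expected to combine the per-block partials written above. A minimal sketch of the standard pairwise Welford merge (Chan's formula) on the host, under the assumption that the stored variance is the population variance; welford_merge is an illustrative name, not an API in this commit:

#include <cstdint>

// Merge Welford partial (mean_b, var_b, count_b) into (mean_a, var_a, count_a).
void welford_merge(float& mean_a, float& var_a, int32_t& count_a,
                   float mean_b, float var_b, int32_t count_b)
{
    const int32_t n = count_a + count_b;
    if(n == 0)
        return;
    const float na    = static_cast<float>(count_a);
    const float nb    = static_cast<float>(count_b);
    const float delta = mean_b - mean_a;
    mean_a  = mean_a + delta * nb / n;
    var_a   = (na * var_a + nb * var_b + delta * delta * na * nb / n) / n;
    count_a = n;
}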