mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Group norm (#417)
* Add groupnorm example by layernorm 1. Reference is not ready 2. shape of gamma and beta need to be fix * Let shape of gamma and beta can be same as x * Modify test, instance and client example * [What] Fix bug of layernorm for greater than 2 dimension. [Why] We need to get upper length from merge transform instead of embed transform. * Add reference for groupnorm * Fuse sigmoid after groupnorm * [What] Rename original layernorm into layernorm2d [Why] Prepare to add groupnorm using layernorm5d * clang-format * Add groupnorm test * Refine error message * Add groupnorm ckProfiler * Test groupnorm kernel from device_instance * update example * upadte profiler * Fix test naming * Fix argc number * Move descriptor and sweeponce to argument for quick debugging Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -22,7 +22,6 @@ template <typename XDataType,
|
||||
typename AccDataType,
|
||||
typename AccElementwiseOperation,
|
||||
typename GridDesc_M_K,
|
||||
typename GridDesc_K,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
@@ -30,7 +29,9 @@ template <typename XDataType,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XSrcVectorDim,
|
||||
index_t XSrcVectorSize,
|
||||
index_t GammaSrcVectorDim,
|
||||
index_t GammaSrcVectorSize,
|
||||
index_t BetaSrcVectorDim,
|
||||
index_t BetaSrcVectorSize,
|
||||
index_t YDstVectorDim,
|
||||
index_t YDstVectorSize,
|
||||
@@ -78,13 +79,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
__device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
|
||||
const GridDesc_K& gamma_grid_desc_k,
|
||||
const GridDesc_K& beta_grid_desc_k,
|
||||
const GridDesc_M_K& gamma_grid_desc_m_k,
|
||||
const GridDesc_M_K& beta_grid_desc_m_k,
|
||||
const GridDesc_M_K& y_grid_desc_m_k,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
@@ -111,11 +113,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
x_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
gamma_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>& beta_thread_buf = gamma_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
y_thread_buf;
|
||||
|
||||
@@ -127,7 +132,7 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
|
||||
mean_square_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_value_buf =
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_thread_buf =
|
||||
mean_square_thread_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
@@ -145,11 +150,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
using ThreadBufferLengths_K = Sequence<KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
constexpr auto thread_buffer_desc_k =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
|
||||
AccDataType,
|
||||
@@ -169,27 +171,34 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
auto threadwise_gamma_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
|
||||
AccDataType,
|
||||
GridDesc_K,
|
||||
decltype(thread_buffer_desc_k),
|
||||
ThreadBufferLengths_K,
|
||||
Sequence<0>,
|
||||
0,
|
||||
GridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
GammaSrcVectorDim,
|
||||
GammaSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
|
||||
gamma_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
|
||||
AccDataType,
|
||||
GridDesc_K,
|
||||
decltype(thread_buffer_desc_k),
|
||||
ThreadBufferLengths_K,
|
||||
Sequence<0>,
|
||||
0,
|
||||
BetaSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
|
||||
auto threadwise_beta_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
|
||||
AccDataType,
|
||||
GridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
BetaSrcVectorDim,
|
||||
BetaSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
beta_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_y_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
@@ -212,9 +221,6 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
|
||||
// Copy x from Cache
|
||||
// one pass: fwd, second pass: bwd
|
||||
constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
|
||||
constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
|
||||
|
||||
constexpr auto thread_copy_fwd_step_m_k =
|
||||
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
|
||||
constexpr auto thread_copy_bwd_step_m_k =
|
||||
@@ -224,13 +230,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
|
||||
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
|
||||
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
// E(x), E[x^2], var(x)
|
||||
int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
|
||||
// FIXME: Should not hack the transform from deviceOP
|
||||
int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
@@ -271,17 +278,16 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
|
||||
|
||||
// var(x) = E[x^2] - E[x]^2
|
||||
var_value_buf(I) =
|
||||
var_thread_buf(I) =
|
||||
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
|
||||
});
|
||||
|
||||
// y = (x - E[x]) / sqrt(var[x] + epsilon)
|
||||
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
|
||||
auto thread_copy_tail_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
|
||||
reducedTiles = 0;
|
||||
@@ -296,10 +302,10 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
x_thread_buf);
|
||||
}
|
||||
|
||||
threadwise_gamma_load.Run(gamma_grid_desc_k,
|
||||
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
|
||||
gamma_global_val_buf,
|
||||
thread_buffer_desc_k,
|
||||
make_tuple(I0),
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
gamma_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
@@ -307,23 +313,21 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
|
||||
sqrt(var_value_buf(iM) + epsilon);
|
||||
sqrt(var_thread_buf(iM) + epsilon);
|
||||
|
||||
// gamma
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
|
||||
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_beta_load.Run(beta_grid_desc_k,
|
||||
threadwise_beta_load.Run(beta_grid_desc_m_k,
|
||||
beta_global_val_buf,
|
||||
thread_buffer_desc_k,
|
||||
make_tuple(I0),
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
beta_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
@@ -331,11 +335,9 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
|
||||
|
||||
// beta
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
|
||||
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -346,8 +348,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
|
||||
y_global_val_buf);
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
|
||||
++reducedTiles;
|
||||
|
||||
@@ -19,7 +19,6 @@ template <typename XDataType,
|
||||
typename AccDataType,
|
||||
typename AccElementwiseOperation,
|
||||
typename GridDesc_M_K,
|
||||
typename GridDesc_K,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
@@ -27,7 +26,9 @@ template <typename XDataType,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XSrcVectorDim,
|
||||
index_t XSrcVectorSize,
|
||||
index_t GammaSrcVectorDim,
|
||||
index_t GammaSrcVectorSize,
|
||||
index_t BetaSrcVectorDim,
|
||||
index_t BetaSrcVectorSize,
|
||||
index_t YDstVectorDim,
|
||||
index_t YDstVectorSize,
|
||||
@@ -70,6 +71,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
@@ -77,7 +79,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
__device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
|
||||
int thread_k_cluster_id)
|
||||
{
|
||||
int kPerBlock = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
|
||||
// FIXME: Should not hack the transform from deviceOP
|
||||
int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
|
||||
int kPerThread =
|
||||
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
|
||||
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
|
||||
@@ -94,8 +97,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
}
|
||||
|
||||
__device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
|
||||
const GridDesc_K& gamma_grid_desc_k,
|
||||
const GridDesc_K& beta_grid_desc_k,
|
||||
const GridDesc_M_K& gamma_grid_desc_m_k,
|
||||
const GridDesc_M_K& beta_grid_desc_m_k,
|
||||
const GridDesc_M_K& y_grid_desc_m_k,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
@@ -116,11 +119,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
x_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
gamma_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>& beta_thread_buf = gamma_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
y_thread_buf;
|
||||
|
||||
@@ -137,11 +143,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
using ThreadBufferLengths_K = Sequence<KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
constexpr auto thread_buffer_desc_k =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
|
||||
AccDataType,
|
||||
@@ -161,27 +164,34 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
auto threadwise_gamma_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
|
||||
AccDataType,
|
||||
GridDesc_K,
|
||||
decltype(thread_buffer_desc_k),
|
||||
ThreadBufferLengths_K,
|
||||
Sequence<0>,
|
||||
0,
|
||||
GridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
GammaSrcVectorDim,
|
||||
GammaSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
|
||||
gamma_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
|
||||
AccDataType,
|
||||
GridDesc_K,
|
||||
decltype(thread_buffer_desc_k),
|
||||
ThreadBufferLengths_K,
|
||||
Sequence<0>,
|
||||
0,
|
||||
BetaSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
|
||||
auto threadwise_beta_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
|
||||
AccDataType,
|
||||
GridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
BetaSrcVectorDim,
|
||||
BetaSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
beta_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_y_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
@@ -204,9 +214,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
|
||||
// Copy x from Cache
|
||||
// one pass: fwd, second pass: bwd
|
||||
constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
|
||||
constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
|
||||
|
||||
constexpr auto thread_copy_fwd_step_m_k =
|
||||
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
|
||||
constexpr auto thread_copy_bwd_step_m_k =
|
||||
@@ -216,10 +223,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
|
||||
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
|
||||
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
auto threadwise_welford = ThreadwiseWelford();
|
||||
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
|
||||
@@ -250,11 +257,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
});
|
||||
|
||||
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
|
||||
auto thread_copy_tail_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
|
||||
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
|
||||
@@ -268,10 +274,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
x_thread_buf);
|
||||
}
|
||||
|
||||
threadwise_gamma_load.Run(gamma_grid_desc_k,
|
||||
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
|
||||
gamma_global_val_buf,
|
||||
thread_buffer_desc_k,
|
||||
make_tuple(I0),
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
gamma_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
@@ -279,8 +285,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
|
||||
@@ -288,14 +292,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
|
||||
// gamma
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
|
||||
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_beta_load.Run(beta_grid_desc_k,
|
||||
threadwise_beta_load.Run(beta_grid_desc_m_k,
|
||||
beta_global_val_buf,
|
||||
thread_buffer_desc_k,
|
||||
make_tuple(I0),
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
beta_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
@@ -303,11 +307,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
|
||||
|
||||
// beta
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
|
||||
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -318,8 +320,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
|
||||
y_global_val_buf);
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user