Add Normalization splitk instances (#829)

* Add normalization splitK to layernorm and groupnorm instances

* Fix bug of GetKPerThread()

* Refine naming

* clang format
This commit is contained in:
rocking
2023-08-12 01:31:31 +08:00
committed by GitHub
parent a5343db00d
commit 03b8119e2e
13 changed files with 120 additions and 11 deletions

View File

@@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st
static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
__device__ static int
GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
{
bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;
if(is_rightmost_block)
{
int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock;
int kPerThread =
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
int left_kPerBlock = math::integer_divide_ceil(k, kGridSize);
int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
int kPerThread = kRightmostBlock < K_BlockTileSize
? 0
: KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize;
if(kPerBlockTail > 0)
{
@@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st
}
else
{
int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
int kPerBlock = math::integer_divide_ceil(k, kGridSize);
return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
}
}
@@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st
auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
threadwise_welford.max_count_ =
GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id);
auto threadwise_welford = ThreadwiseWelford();
int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1),
kRaw,
k_grid_size,
block_k_cluster_id,
thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);