mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add Normalization splitk instances (#829)
* Add normalization splitK to layernorm and groupnorm instances
* Fix bug of GetKPerThread()
* Refine naming
* clang format
This commit is contained in:
@@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st
|
||||
static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
|
||||
|
||||
__device__ static int
|
||||
GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
|
||||
GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
|
||||
{
|
||||
bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;
|
||||
|
||||
if(is_rightmost_block)
|
||||
{
|
||||
int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
|
||||
int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock;
|
||||
int kPerThread =
|
||||
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
|
||||
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
|
||||
int left_kPerBlock = math::integer_divide_ceil(k, kGridSize);
|
||||
int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
|
||||
int kPerThread = kRightmostBlock < K_BlockTileSize
|
||||
? 0
|
||||
: KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
|
||||
int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize;
|
||||
|
||||
if(kPerBlockTail > 0)
|
||||
{
|
||||
@@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st
|
||||
}
|
||||
else
|
||||
{
|
||||
int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
|
||||
int kPerBlock = math::integer_divide_ceil(k, kGridSize);
|
||||
return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
|
||||
}
|
||||
}
|
||||
@@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st
|
||||
auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
|
||||
|
||||
auto threadwise_welford = ThreadwiseWelford();
|
||||
int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
|
||||
threadwise_welford.max_count_ =
|
||||
GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id);
|
||||
auto threadwise_welford = ThreadwiseWelford();
|
||||
int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
|
||||
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1),
|
||||
kRaw,
|
||||
k_grid_size,
|
||||
block_k_cluster_id,
|
||||
thread_k_cluster_id);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
|
||||
|
||||
Reference in New Issue
Block a user