mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Add CountDataType as template parameter in blockwise_welford
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -35,10 +35,11 @@ struct BlockwiseWelford
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
template <typename CountDataType>
|
||||
__device__ static inline void
|
||||
Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
|
||||
Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
|
||||
{
|
||||
int count = count_a + count_b;
|
||||
CountDataType count = count_a + count_b;
|
||||
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
|
||||
T delta = mean_b - mean_a;
|
||||
mean_a += delta * count_b_over_count;
|
||||
@@ -46,11 +47,12 @@ struct BlockwiseWelford
|
||||
count_a = count;
|
||||
}
|
||||
|
||||
__device__ static void Run(T& mean_value, T& var_value, int& count)
|
||||
template <typename CountDataType>
|
||||
__device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
|
||||
{
|
||||
__shared__ T mean_block_buf[BlockSize];
|
||||
__shared__ T var_block_buf[BlockSize];
|
||||
__shared__ int count_block_buf[BlockSize];
|
||||
__shared__ CountDataType count_block_buf[BlockSize];
|
||||
|
||||
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
|
||||
|
||||
@@ -76,13 +78,13 @@ struct BlockwiseWelford
|
||||
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
|
||||
make_tuple(0, indOffset));
|
||||
|
||||
T mean1 = mean_block_buf[offset1];
|
||||
T var1 = var_block_buf[offset1];
|
||||
int count1 = count_block_buf[offset1];
|
||||
T mean1 = mean_block_buf[offset1];
|
||||
T var1 = var_block_buf[offset1];
|
||||
CountDataType count1 = count_block_buf[offset1];
|
||||
|
||||
T mean2 = mean_block_buf[offset2];
|
||||
T var2 = var_block_buf[offset2];
|
||||
int count2 = count_block_buf[offset2];
|
||||
T mean2 = mean_block_buf[offset2];
|
||||
T var2 = var_block_buf[offset2];
|
||||
CountDataType count2 = count_block_buf[offset2];
|
||||
|
||||
Merge(mean1, var1, count1, mean2, var2, count2);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user