diff --git a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp index a3813ea248..7506f7072b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -35,10 +35,11 @@ struct BlockwiseWelford static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + template __device__ static inline void - Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) + Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b) { - int count = count_a + count_b; + CountDataType count = count_a + count_b; T count_b_over_count = count == 0 ? type_convert(0) : type_convert(count_b) / count; T delta = mean_b - mean_a; mean_a += delta * count_b_over_count; @@ -46,11 +47,12 @@ struct BlockwiseWelford count_a = count; } - __device__ static void Run(T& mean_value, T& var_value, int& count) + template + __device__ static void Run(T& mean_value, T& var_value, CountDataType& count) { __shared__ T mean_block_buf[BlockSize]; __shared__ T var_block_buf[BlockSize]; - __shared__ int count_block_buf[BlockSize]; + __shared__ CountDataType count_block_buf[BlockSize]; constexpr auto cluster_len_shift = get_shift(); @@ -76,13 +78,13 @@ struct BlockwiseWelford index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + make_tuple(0, indOffset)); - T mean1 = mean_block_buf[offset1]; - T var1 = var_block_buf[offset1]; - int count1 = count_block_buf[offset1]; + T mean1 = mean_block_buf[offset1]; + T var1 = var_block_buf[offset1]; + CountDataType count1 = count_block_buf[offset1]; - T mean2 = mean_block_buf[offset2]; - T var2 = var_block_buf[offset2]; - int count2 = count_block_buf[offset2]; + T mean2 = mean_block_buf[offset2]; + T var2 = var_block_buf[offset2]; + CountDataType count2 = count_block_buf[offset2]; Merge(mean1, var1, count1, mean2, var2, count2);