mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Fix for beta!=0 in reduce (#1440)
* fix for beta!=0 in reduce * add reviewers suggestions
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
|
||||
auto invoker_ptr = reduce.MakeInvokerPointer();
|
||||
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
int log_level = 0, cold_niters = 5, nrepeat = 50;
|
||||
if(beta != 0.0f)
|
||||
{
|
||||
std::cerr << "Warning: With beta != 0.0f there must be only one repeat for correct results "
|
||||
"since out memory is being overwritten."
|
||||
<< std::endl;
|
||||
cold_niters = 0;
|
||||
nrepeat = 1;
|
||||
}
|
||||
float avg_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, log_level, cold_niters, nrepeat});
|
||||
|
||||
std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) +
|
||||
invariant_total_length * sizeof(InOutDataType);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -258,7 +258,7 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
if(block_group_size == 0 && !float_equal_zero{}(beta_values[iR]))
|
||||
if(!float_equal_zero{}(beta_values[iR]))
|
||||
{
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValueBuf;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -244,7 +244,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
if(block_group_size == 0 && !float_equal_zero{}(beta))
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValueBuf;
|
||||
|
||||
Reference in New Issue
Block a user