mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Fix for beta!=0 in reduce (#1440)
* fix for beta!=0 in reduce
* add reviewers' suggestions
[ROCm/composable_kernel commit: b74d4d4d54]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
|
||||
auto invoker_ptr = reduce.MakeInvokerPointer();
|
||||
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
int log_level = 0, cold_niters = 5, nrepeat = 50;
|
||||
if(beta != 0.0f)
|
||||
{
|
||||
std::cerr << "Warning: With beta != 0.0f there must be only one repeat for correct results "
|
||||
"since out memory is being overwritten."
|
||||
<< std::endl;
|
||||
cold_niters = 0;
|
||||
nrepeat = 1;
|
||||
}
|
||||
float avg_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, log_level, cold_niters, nrepeat});
|
||||
|
||||
std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) +
|
||||
invariant_total_length * sizeof(InOutDataType);
|
||||
|
||||
Reference in New Issue
Block a user