mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Grouped convolution backward weight special vector size loads (#1772)
* Grouped convolution backward weight special vector size loads * Instances and tests * Fixes * Add 7 and 13 special cases * fix comments * Fix * Fix2 * fixes * fix atomic add bf16
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -89,6 +89,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
out_device_buf.ToDevice(output.mData.data());
|
||||
|
||||
float max_accumulated_value = 0;
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
@@ -114,6 +115,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
max_accumulated_value =
|
||||
*std::max_element(weight_host_result.mData.begin(), weight_host_result.mData.end());
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
@@ -237,7 +240,39 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
{
|
||||
wei_device_buf.FromDevice(weight_device_result.mData.data());
|
||||
|
||||
bool pass = ck::utils::check_err(weight_device_result, weight_host_result);
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeTypeA) < sizeof(ComputeTypeB),
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
const index_t num_accums = output.GetElementSize() / conv_param.K_;
|
||||
const index_t num_accums_split_k = split_k_list[split_k_id];
|
||||
// Calculate thresholds
|
||||
auto rtol =
|
||||
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
num_accums / num_accums_split_k);
|
||||
auto atol =
|
||||
ck::utils::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
max_accumulated_value / num_accums_split_k,
|
||||
num_accums / num_accums_split_k);
|
||||
// Calculate error due to split_k accumulation
|
||||
auto rtol_split_k =
|
||||
ck::utils::get_relative_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
num_accums_split_k);
|
||||
auto atol_split_k =
|
||||
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
max_accumulated_value, num_accums_split_k);
|
||||
// Use higher threshold
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
bool pass = ck::utils::check_err(weight_device_result,
|
||||
weight_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
std::cout << "Relative error threshold: " << rtol
|
||||
<< " Absolute error threshold: " << atol << std::endl;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user