mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Grouped convolution backward weight special vector size loads (#1772)
* Grouped convolution backward weight special vector size loads * Instances and tests * Fixes * Add 7 and 13 special cases * fix comments * Fix * Fix2 * fixes * fix atomic add bf16
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -1558,14 +1558,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
if(!(arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
const bool is_w_pad_zero = arg.input_left_pads_[NDimSpatial - 1] == 0 &&
|
||||
arg.input_right_pads_[NDimSpatial - 1] == 0;
|
||||
const auto X = arg.filter_spatial_lengths_[NDimSpatial - 1];
|
||||
const bool XC_access_allowed = arg.Conv_G_ == 1 &&
|
||||
(arg.Conv_C_ * X) % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
is_w_pad_zero;
|
||||
|
||||
if(!((arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 || XC_access_allowed) &&
|
||||
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1))
|
||||
if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1 &&
|
||||
NumGroupsToMerge > 1))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1))
|
||||
if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1 &&
|
||||
NumGroupsToMerge > 1))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -584,6 +584,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
Reference in New Issue
Block a user