mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
[MIOpen][CK] Fix bwd weight conv test failures by disabling one block-GEMM V5 instance for 3D convs (#6421) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Due to a compiler version update, there are test failures in the test target `test_grouped_convnd_bwd_weight` when running on `gfx90a`. There are four failing tests for FP16/BF16 that arise from a single kernel instance. As the problem is in the current develop branch, the test failures are blocking any PR merges into develop. An example of a failed CI run is here: [http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/develop/558/pipeline/](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/develop/558/pipeline/). The underlying compiler problem is potentially the same as described in #6342, as the tests are passing for clang compiler version 20.0 and failing for clang compiler version 22.0. The first attempt to fix this problem had to be reverted in #6400 because it broke MIOpen internal DB sync tests. ## Technical Details The root cause of the test failures is the block-GEMM V5 instances of `DeviceGroupedConvBwdWeight_Xdl_CShuffleV3` that have a large tile size. The V5 pipeline uses a double register buffer that, in combination with a large tile size, causes high register pressure. The latest version of the compiler handles the register spillage incorrectly for `gfx90a`, which causes the kernel to output incorrect results. The BF16/FP16 instances of `DeviceGroupedConvBwdWeight_Xdl_CShuffleV3` that do not use direct load are divided into two groups: - Base instances - Instances that result in high register usage (currently only one instance - the one that causes the test failures). 
This division allows disabling only the V5 block-GEMM flavor of `DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<64, 128, 32, 32, Default, 8, 4, 1, 8, 8, 8, 8, 1, 1, 2>` for 3D convolutions on `gfx90a`. The selective disabling leaves the set of instances for 1D and 2D convolutions unaffected, and removes at runtime two V5 block-GEMM instances (`ConvBwdWeightDefault` and `ConvBwdWeightFilter1x1Stride1Pad0`) per data type (FP16/BF16) when the device is `gfx90a`. Because MIOpen uses CK's type string (provided by method `GetTypeString`) to identify the instances, the DB sync tests are expected to be unaffected since there are still the V2 block-GEMM instances that result in the same type string (`DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<64, 128, 32, 32, Default, 8, 4, 1, 8, 8, 8, 8, 1, 1, 2>`). This expectation needs to be verified by running the MIOpen DB sync tests that are not part of the normal CK PR build. ## Test Plan Running all CI tests + the MIOpen internal DB sync tests is sufficient to verify the correctness of the code changes. ## Test Result Verified locally that the previously failing tests `TestGroupedConvndBwdWeight3d/4.Test3D` and `TestGroupedConvndBwdWeight3d/4.Test3D` have instance counts - 231 on `gfx90a` - 233 on `gfx942` and are currently passing. This confirms the expectation that two instances per data type should be disabled on `gfx90a`. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. Co-authored-by: Ville Pietilä <>
164 lines
4.7 KiB
C++
164 lines
4.7 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#ifndef __HIPCC_RTC__
|
|
#include <string>
|
|
#include <string_view>
|
|
#include <hip/hip_runtime.h>
|
|
|
|
namespace ck {
|
|
|
|
// FNV-1a style hash of `str`, usable in constant expressions (e.g. as a case label).
//
// `h` is the running hash state; its default is the value hashing starts from.
// Passing the result of a previous call as `h` continues hashing, so
// fnv1a_hash("bar", fnv1a_hash("foo")) equals fnv1a_hash("foobar").
constexpr unsigned int fnv1a_hash(std::string_view str, unsigned int h = 2166136261u)
{
    for(const char ch : str)
    {
        // Go through unsigned char so that negative `char` values are not sign-extended
        // before the XOR; the unsigned multiplication wraps on overflow.
        h = (h ^ static_cast<unsigned char>(ch)) * 16777619u;
    }
    return h;
}
|
|
inline std::string get_device_name()
|
|
{
|
|
hipDeviceProp_t props{};
|
|
int device;
|
|
auto status = hipGetDevice(&device);
|
|
if(status != hipSuccess)
|
|
{
|
|
return std::string();
|
|
}
|
|
status = hipGetDeviceProperties(&props, device);
|
|
if(status != hipSuccess)
|
|
{
|
|
return std::string();
|
|
}
|
|
const std::string raw_name(props.gcnArchName);
|
|
const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
|
|
switch(fnv1a_hash(name))
|
|
{
|
|
// https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
|
|
case fnv1a_hash("Ellesmere"):
|
|
case fnv1a_hash("Baffin"):
|
|
case fnv1a_hash("RacerX"):
|
|
case fnv1a_hash("Polaris10"):
|
|
case fnv1a_hash("Polaris11"):
|
|
case fnv1a_hash("Tonga"):
|
|
case fnv1a_hash("Fiji"):
|
|
case fnv1a_hash("gfx800"):
|
|
case fnv1a_hash("gfx802"):
|
|
case fnv1a_hash("gfx804"): return "gfx803";
|
|
case fnv1a_hash("Vega10"):
|
|
case fnv1a_hash("gfx901"): return "gfx900";
|
|
case fnv1a_hash("10.3.0 Sienna_Cichlid 18"): return "gfx1030";
|
|
default: return name;
|
|
}
|
|
}
|
|
|
|
// True when the current HIP device reports the architecture name "gfx90a".
inline bool is_gfx90a()
{
    const auto arch = ck::get_device_name();
    return arch == "gfx90a";
}
|
|
|
|
inline bool is_gfx12_supported()
|
|
{
|
|
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
|
|
}
|
|
|
|
inline bool is_gfx11_supported()
|
|
{
|
|
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
|
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" ||
|
|
ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" ||
|
|
ck::get_device_name() == "gfx1152" || ck::get_device_name() == "gfx1153";
|
|
}
|
|
|
|
inline bool is_xdl_supported()
|
|
{
|
|
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
|
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
|
|
is_gfx12_supported() || is_gfx11_supported();
|
|
}
|
|
|
|
template <typename ADataType,
|
|
typename BDataType,
|
|
index_t MPerXDL64,
|
|
index_t NPerXDL64,
|
|
index_t MPerXDL32 = MPerXDL64,
|
|
index_t NPerXDL32 = NPerXDL64>
|
|
inline bool is_xdl_wmma_supported()
|
|
{
|
|
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
|
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")
|
|
{
|
|
return true;
|
|
}
|
|
else if(is_gfx12_supported() || is_gfx11_supported())
|
|
{
|
|
if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16))
|
|
{
|
|
return false;
|
|
}
|
|
if constexpr(sizeof(ADataType) > 2 || sizeof(BDataType) > 2)
|
|
{
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
else
|
|
{
|
|
return false;
|
|
}
|
|
}
|
|
|
|
inline bool is_lds_direct_load_supported()
|
|
{
|
|
// Check if direct loads from global memory to LDS are supported.
|
|
return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx942" ||
|
|
ck::get_device_name() == "gfx950";
|
|
}
|
|
|
|
inline bool is_bf16_atomic_supported()
|
|
{
|
|
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
|
|
is_gfx12_supported();
|
|
}
|
|
|
|
inline bool is_gfx101_supported()
|
|
{
|
|
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
|
|
ck::get_device_name() == "gfx1012";
|
|
}
|
|
|
|
inline bool is_gfx103_supported()
|
|
{
|
|
return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
|
|
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1033" ||
|
|
ck::get_device_name() == "gfx1034" || ck::get_device_name() == "gfx1035" ||
|
|
ck::get_device_name() == "gfx1036";
|
|
}
|
|
|
|
inline bool is_wmma_supported()
|
|
{
|
|
return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported();
|
|
}
|
|
|
|
inline bool is_tf32_supported()
|
|
{
|
|
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
|
|
}
|
|
|
|
// Returns the hipDeviceAttributeMaxSharedMemoryPerBlock attribute of the current HIP
// device. If the current device or the attribute cannot be queried, 64 * 1024 is
// returned instead.
inline int __host__ get_lds_size()
{
    constexpr int fallback_size = 64 * 1024;

    int device_id = 0;
    if(hipGetDevice(&device_id) != hipSuccess)
    {
        return fallback_size;
    }

    int lds_size = 0;
    if(hipDeviceGetAttribute(&lds_size, hipDeviceAttributeMaxSharedMemoryPerBlock, device_id) !=
       hipSuccess)
    {
        return fallback_size;
    }

    return lds_size;
}
|
|
|
|
} // namespace ck
|
|
#endif
|