[CK] Disable compilation of problematic bwd weight conv instances for gfx90a (#6343)

## Motivation

Due to compiler version update, there are test failures in the test
suite `test_grouped_convnd_bwd_weight` when running on `gfx90a`. There
are four failing tests for FP16/BF16 that arise from a single kernel
instance. As the problem is in the current `develop` branch, the test
failures are blocking any PR merges into `develop`. An example of a
failed CI runs is here:
[http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/develop/558/pipeline/](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/develop/558/pipeline/).
The underlying compiler problem is potentially the same as described in
#6342 as tests are passing for clang compiler version 20.0 and failing
for clang compiler version 22.0.

## Technical Details

This PR disables the compilation of the problematic bwd weight conv
instance for `gfx90a` by adding a new CMake flag `CK_USE_GFX90A` that
allows us to detect when we are compiling for `gfx90a`. Using the new
CMake flag, compilation of instance
`DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<64, 128, 32, 32, Default, 8,
4, 1, 8, 8, 8, 8, 1, 1, 2>` is disabled for `gfx90a`.

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2026-04-13 14:40:27 +03:00
committed by GitHub
parent 9fe98c864f
commit 1b2a619107
4 changed files with 42 additions and 0 deletions

View File

@@ -274,6 +274,11 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950" AND NOT FORCE_DISABLE_XDL)
add_definitions(-DCK_USE_GFX950)
set(CK_USE_GFX950 "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" AND NOT FORCE_DISABLE_XDL)
add_definitions(-DCK_USE_GFX90A)
set(CK_USE_GFX90A "ON")
endif()
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
set(CK_TILE_USE_WMMA 0)

View File

@@ -209,6 +209,17 @@
#endif
#endif
// workaround for AMDGPU compiler VGPR aliasing bug in dropout codegen (ROCm >= 7.12)
// Philox RNG VGPR parameters get aliased under high register pressure (d256 tile).
// fp16 is affected; bf16 is not (different type conversion codegen path).
#ifndef CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE
#if(HIP_VERSION_MAJOR == 7 && HIP_VERSION_MINOR >= 12) || (HIP_VERSION_MAJOR > 7)
#define CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE 1
#else
#define CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif

View File

@@ -95,7 +95,16 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = std::tuple
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
// Problematic instance on gfx90a - accuracy tests fail for 3D bwd weight conv as the instance produces incorrect results.
// The problem occurs at least for compiler version
// 22.0.0git (https://github.com/ROCm/llvm-project.git 2de9eb6063dd56b109cf139a75550b7b06808273+PATCHED:9a6ac45c97a1e511db838c5b46257324d2de1780)
// Older compilers from the 20.0 family produce correct results.
#if defined(CK_USE_GFX90A)
#else
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
#endif
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>
// clang-format on
@@ -168,7 +177,16 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances = std::tupl
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
// Problematic instance on gfx90a - accuracy tests fail for 3D bwd weight conv as the instance produces incorrect results.
// The problem occurs at least for compiler version
// 22.0.0git (https://github.com/ROCm/llvm-project.git 2de9eb6063dd56b109cf139a75550b7b06808273+PATCHED:9a6ac45c97a1e511db838c5b46257324d2de1780)
// Older compilers from the 20.0 family produce correct results.
#if defined(CK_USE_GFX90A)
#else
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
#endif
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>
//clang-format on

View File

@@ -601,6 +601,14 @@ TEST_P(Dropout, DataTypeConfig)
auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
#if CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp16>)
{
if(hdim_q > 128 && mode == mode_enum::batch)
GTEST_SKIP() << "Skipped: fp16 dropout d256 batch — compiler bug (ROCm >= 7.12)";
}
#endif
auto result = fmha_fwd_run<DataTypeConfig>(mode,
batch,
nhead,