[rocm-libraries] ROCm/rocm-libraries#7925 (commit a8f0845)

[CK] Fix gfx950 AITER Sync Regressions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Fixes three gfx950 regressions in the AITER downstream CI that surfaced
after the internal/gfx1250 re-sync (ROCm/rocm-libraries#6978):

> **Companion aiter PR:** ROCm/aiter#3392 — host-side adaptations
(`Kernel::BlockSize()` `constexpr` drops, blockscale `KBatch=1` clamp)
plus the CK submodule bump used to validate these fixes together.

- **FlyDSL MoE AOT cache miss** — the AITER MoE tests run with
`check_aot_cache=True` and fail on any FlyDSL JIT cache miss, but the CI
never pre-compiles the FlyDSL MoE kernels, so gfx950 always misses.
Pre-compile them at the start of the AITER test stage.
- **`buffer.load.lds.v4i32` link error** — ROCm/rocm-libraries#6978
reintroduced a clang-version guard mapping
`llvm.amdgcn.raw.buffer.load.lds` to a `.v4i32`-suffixed name. That name
exists in no LLVM (the rsrc operand is a fixed, non-overloaded `<4 x
i32>`, so the intrinsic is never type-mangled), so gfx950 4-DWORD
direct-to-LDS (e.g. fp4 MoE bpreshuffle) fails to link with `lld:
undefined symbol: llvm.amdgcn.raw.buffer.load.lds.v4i32`. Use the
canonical plain name unconditionally.
- **mixed-precision flatmm warp-GEMM call** — ROCm/rocm-libraries#6978
generalized the scaled `WarpGemmImpl::operator()` from a fixed `<index_t
opselA, index_t opselB>` signature to a variadic `<typename... Params>`
one and updated the `mx_flatmm` pipeline to pass the op-selectors as
`OpSelA<>`/`OpSelB<>` types, but missed the mixed-precision flatmm
pipeline (`F8xMXF4`/`F16xMXF4`), which still passed raw integer
op-selectors. These no longer bind to `typename... Params` (`error: no
matching member function for call to 'operator()'`), breaking
compilation of the fp8/bf16 × fp4 cktile MoE gemm1 instances on gfx950
(aiter `test_moe_2stage`). Wrap the op-selectors in
`OpSelA<>`/`OpSelB<>`.

## Changes

- `Jenkinsfile`: pre-compile the FlyDSL MoE AOT cache (`python3
aiter/aot/flydsl/moe.py`) before the AITER tests.
- `include/ck/utility/amd_buffer_addressing_builtins.hpp` and
`include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp`: drop the
`__clang_major__` guard and always use
`__asm("llvm.amdgcn.raw.buffer.load.lds")`. The plain name is the
canonical one for all sizes including the gfx950 16-byte form, as the
upstream LLVM gfx950 tests confirm.
-
`include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp`:
wrap the warp-GEMM op-selectors in `OpSelA<>`/`OpSelB<>` at the five
call sites, matching the `mx_flatmm` pipeline.

## Test plan

Validated via CI.
This commit is contained in:
Yi DING
2026-06-03 02:09:05 +00:00
committed by assistant-librarian[bot]
parent 5720589311
commit 01bd52bdb5
4 changed files with 9 additions and 27 deletions

View File

@@ -830,16 +830,6 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
}
// Direct loads from global to LDS.
#if __clang_major__ >= 21 && __clang_major__ < 23
__device__ void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32");
#else
__device__ void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
@@ -848,7 +838,6 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
#endif
#ifndef __HIPCC_RTC__
template <typename T, index_t NumElemsPerThread>

View File

@@ -1381,16 +1381,6 @@ CK_TILE_DEVICE_EXTERN double llvm_amdgcn_raw_buffer_atomic_max_fp64(
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32");
// Direct loads from global to LDS.
#if __clang_major__ >= 21 && __clang_major__ < 23
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
as3_uint32_ptr lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32");
#else
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
as3_uint32_ptr lds_ptr,
@@ -1399,7 +1389,6 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
#endif
template <unsigned num_dwords, bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem,

View File

@@ -1982,7 +1982,7 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1
// warp GEMM
WG{}.template
// operator()<MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
operator()<OpSelA<ikxdl * MXdlPack + imxdl>, OpSelB<ikxdl * NXdlPack + inxdl>>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} +
@@ -2092,7 +2092,7 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1
// warp GEMM
WG{}.template
// operator()<MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
operator()<OpSelA<ikxdl * MXdlPack + imxdl>, OpSelB<ikxdl * NXdlPack + inxdl>>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} +
@@ -2214,7 +2214,8 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}.template operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
WG{}.template
operator()<OpSelA<ikxdl * MXdlPack + imxdl>, OpSelB<ikxdl * NXdlPack + inxdl>>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} +
@@ -2283,7 +2284,8 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}.template operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
WG{}.template
operator()<OpSelA<ikxdl * MXdlPack + imxdl>, OpSelB<ikxdl * NXdlPack + inxdl>>(
// operator()<MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
@@ -2346,7 +2348,7 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1
// warp GEMM
WG{}.template
// operator()<MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
operator()<OpSelA<ikxdl * MXdlPack + imxdl>, OpSelB<ikxdl * NXdlPack + inxdl>>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} +

View File

@@ -1180,6 +1180,8 @@ def getPytorchTestsCmds() {
}
def getAiterTestsCmds() {
return [
// Pre-compile FlyDSL MoE AOT cache before the tests.
"cd /home/jenkins/workspace/aiter && python3 aiter/aot/flydsl/moe.py",
"python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py",
"python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py",
"python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py",