[rocm-libraries] ROCm/rocm-libraries#6611 (commit 5375c0f)

[CK_TILE] Preserve input strides in EightWaves async-load
 descriptor (#6611)

`MakeAsyncLoadADramWindow` in
`GemmPipelineAgBgCrCompAsyncEightWavesPolicy` was rebuilding the 6D view
descriptor with `make_naive_tensor_descriptor_packed`, which synthesizes
strides from lengths and assumes a dense layout. When the input view's
leading-dim stride is larger than its inner length (non-packed memory
layout), the resulting tile window stepped through memory at the wrong
stride.

Compose the unmerge transforms on top of the input view's existing
descriptor instead, so the actual runtime strides are preserved and the
correct `element_space_size` is inherited for bounds checking.

## Test Plan

Added a unit test that reproduces the problem.

## Test Result

The new test fails before the fix and passes after it.

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Sami Remes
2026-04-22 10:52:59 +00:00
committed by assistant-librarian[bot]
parent 9d34174ac2
commit cbfb3e242e
4 changed files with 53 additions and 6 deletions

View File

@@ -86,6 +86,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_abquant_eightwaves PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_abquant_eightwaves_padded_stride
test_gemm_quant_abquant_eightwaves_padded_stride.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_eightwaves_padded_stride PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# ABQuant split-K tests
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode
test_gemm_quant_abquant_splitk_decode.cpp
@@ -281,6 +286,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
test_tile_gemm_quant_abquant_a4w4_padding
test_tile_gemm_quant_abquant_a4w4_preshuffle
test_tile_gemm_quant_abquant_eightwaves
test_tile_gemm_quant_abquant_eightwaves_padded_stride
# ABQuant split-K tests
test_tile_gemm_quant_abquant_splitk_decode
test_tile_gemm_quant_abquant_splitk_prefill

View File

@@ -0,0 +1,31 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Regression test for the EightWaves ABQuant pipeline on a B tensor whose
// leading-dim stride is larger than the packed value. The async B-load
// descriptor in the EightWaves policy must be built from the input view's
// real strides so that the kernel addresses B correctly when stride_B is
// larger than the inner length (e.g. row-aligned weight padding).
#include "test_gemm_quant_common.hpp"
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
#ifdef CK_GFX950_SUPPORT
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantEightWavesPaddedStrideTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWaves, GroupSize1D_128, GroupSize2D128N, ColumnMajor>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantEightWavesPaddedStrideTypes);
// Regression test: run the EightWaves ABQuant GEMM against a B tensor whose
// leading-dim stride is larger than the packed (dense) value, so the kernel
// must honor the runtime stride rather than a synthesized packed one.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedPaddedBStrideTest)
{
// 256-byte row alignment for FP8 -> 256 elements of leading-dim padding.
constexpr ck_tile::index_t k_batch = 1;
constexpr ck_tile::index_t stride_B_pad = 256;
// M = N = K = 1024 with a single K-split. The fixture adds stride_B_pad on
// top of the default packed stride of B and — presumably — validates the
// device result against a host reference (see run_test_with_validation in
// the fixture; confirm validation path there).
this->run_test_with_validation(1024, 1024, 1024, k_batch, stride_B_pad);
}
#endif

View File

@@ -1038,12 +1038,17 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
void run_test_with_validation(ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t k_batch = 1)
ck_tile::index_t k_batch = 1,
ck_tile::index_t stride_B_pad = 0)
{
const ck_tile::index_t stride_A =
ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{}));
// stride_B_pad lets a test exercise a B tensor whose leading-dim stride is
// larger than the packed value (e.g. row-aligned padding). The host tensor,
// device buffer, and kernel args are all built with this padded stride so
// the kernel must honor the runtime stride to address B correctly.
const ck_tile::index_t stride_B =
ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{}));
ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})) + stride_B_pad;
const ck_tile::index_t stride_C =
ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{}));