From cbfb3e242ecd51aa14dde9a4ae6f581a2f52c4cf Mon Sep 17 00:00:00 2001 From: Sami Remes <181322991+samremes@users.noreply.github.com> Date: Wed, 22 Apr 2026 10:52:59 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#6611 (commit 5375c0f) [CK_TILE] Preserve input strides in EightWaves async-load descriptor (#6611) `MakeAsyncLoadADramWindow` in `GemmPipelineAgBgCrCompAsyncEightWavesPolicy` was rebuilding the 6D view descriptor with `make_naive_tensor_descriptor_packed`, which synthesizes strides from lengths and assumes a dense layout. When the input view's leading-dim stride is larger than its inner length (non-packed memory layout), the resulting tile window stepped through memory at the wrong stride. Compose the unmerge transforms on top of the input view's existing descriptor instead, so the actual runtime strides are preserved and the correct `element_space_size` is inherited for bounds checking. ## Test Plan Added an unit test showing the problem. ## Test Result The new test fails before fixes and passes after. ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- ...ag_bg_cr_comp_async_eight_waves_policy.hpp | 13 +++++--- test/ck_tile/gemm_block_scale/CMakeLists.txt | 6 ++++ ...quant_abquant_eightwaves_padded_stride.cpp | 31 +++++++++++++++++++ .../test_gemm_quant_fixtures.hpp | 9 ++++-- 4 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwaves_padded_stride.cpp diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp index 29991197cd..1e1f525c3b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp @@ -176,10 +176,15 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy const index_t M0 = integer_divide_ceil(rows, M1); const auto row_lens = make_tuple(M0, number{}); - const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens)); - const auto desc_0 = decltype(d0)( // set correct size (without padding) - d0.get_transforms(), - tensor_view_tmp.get_tensor_descriptor().get_element_space_size()); + // Build the 6D view by composing unmerge transforms on top of the + // input view's existing descriptor. This preserves the input's actual + // strides (so a non-packed leading-dim stride is honored) and inherits + // its element_space_size for bounds checking. + const auto desc_0 = transform_tensor_descriptor( + tensor_view_tmp.get_tensor_descriptor(), + make_tuple(make_unmerge_transform(row_lens), make_unmerge_transform(col_lens)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{})); const auto desc_1 = transform_tensor_descriptor( desc_0, make_tuple(make_pass_through_transform(M0), diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 9f77cf01d7..21d34f7b34 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -86,6 +86,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_eightwaves PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_eightwaves_padded_stride + test_gemm_quant_abquant_eightwaves_padded_stride.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_eightwaves_padded_stride PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # ABQuant split-K tests add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode test_gemm_quant_abquant_splitk_decode.cpp @@ -281,6 +286,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") test_tile_gemm_quant_abquant_a4w4_padding test_tile_gemm_quant_abquant_a4w4_preshuffle test_tile_gemm_quant_abquant_eightwaves + test_tile_gemm_quant_abquant_eightwaves_padded_stride # ABQuant split-K tests test_tile_gemm_quant_abquant_splitk_decode test_tile_gemm_quant_abquant_splitk_prefill diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwaves_padded_stride.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwaves_padded_stride.cpp new file mode 100644 index 0000000000..28b7811af3 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwaves_padded_stride.cpp @@ -0,0 +1,31 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Regression test for the EightWaves ABQuant pipeline on a B tensor whose +// leading-dim stride is larger than the packed value. The async B-load +// descriptor in the EightWaves policy must be built from the input view's +// real strides so that the kernel addresses B correctly when stride_B is +// larger than the inner length (e.g. row-aligned weight padding). + +#include "test_gemm_quant_common.hpp" + +using GroupSize2D128N = ck_tile::QuantGroupShape>; +#ifdef CK_GFX950_SUPPORT +// Tuple format: +// clang-format off +using ABQuantEightWavesPaddedStrideTypes = ::testing::Types< + std::tuple +>; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantEightWavesPaddedStrideTypes); + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedPaddedBStrideTest) +{ + // 256-byte row alignment for FP8 -> 256 elements of leading-dim padding. + constexpr ck_tile::index_t k_batch = 1; + constexpr ck_tile::index_t stride_B_pad = 256; + this->run_test_with_validation(1024, 1024, 1024, k_batch, stride_B_pad); +} +#endif diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 8fbda4a3ce..e5731c5caa 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -1038,12 +1038,17 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBaseis_row_major(ALayout{})); + // stride_B_pad lets a test exercise a B tensor whose leading-dim stride is + // larger than the packed value (e.g. row-aligned padding). The host tensor, + // device buffer, and kernel args are all built with this padded stride so + // the kernel must honor the runtime stride to address B correctly. const ck_tile::index_t stride_B = - ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})); + ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})) + stride_B_pad; const ck_tile::index_t stride_C = ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{}));