diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp index 29991197cd..1e1f525c3b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp @@ -176,10 +176,15 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy const index_t M0 = integer_divide_ceil(rows, M1); const auto row_lens = make_tuple(M0, number{}); - const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens)); - const auto desc_0 = decltype(d0)( // set correct size (without padding) - d0.get_transforms(), - tensor_view_tmp.get_tensor_descriptor().get_element_space_size()); + // Build the 6D view by composing unmerge transforms on top of the + // input view's existing descriptor. This preserves the input's actual + // strides (so a non-packed leading-dim stride is honored) and inherits + // its element_space_size for bounds checking. + const auto desc_0 = transform_tensor_descriptor( + tensor_view_tmp.get_tensor_descriptor(), + make_tuple(make_unmerge_transform(row_lens), make_unmerge_transform(col_lens)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{})); const auto desc_1 = transform_tensor_descriptor( desc_0, make_tuple(make_pass_through_transform(M0), diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 9f77cf01d7..21d34f7b34 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -86,6 +86,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_eightwaves PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_eightwaves_padded_stride + test_gemm_quant_abquant_eightwaves_padded_stride.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_eightwaves_padded_stride PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # ABQuant split-K tests add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode test_gemm_quant_abquant_splitk_decode.cpp @@ -281,6 +286,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") test_tile_gemm_quant_abquant_a4w4_padding test_tile_gemm_quant_abquant_a4w4_preshuffle test_tile_gemm_quant_abquant_eightwaves + test_tile_gemm_quant_abquant_eightwaves_padded_stride # ABQuant split-K tests test_tile_gemm_quant_abquant_splitk_decode test_tile_gemm_quant_abquant_splitk_prefill diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwaves_padded_stride.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwaves_padded_stride.cpp new file mode 100644 index 0000000000..28b7811af3 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwaves_padded_stride.cpp @@ -0,0 +1,31 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Regression test for the EightWaves ABQuant pipeline on a B tensor whose +// leading-dim stride is larger than the packed value. The async B-load +// descriptor in the EightWaves policy must be built from the input view's +// real strides so that the kernel addresses B correctly when stride_B is +// larger than the inner length (e.g. row-aligned weight padding). + +#include "test_gemm_quant_common.hpp" + +using GroupSize2D128N = ck_tile::QuantGroupShape>; +#ifdef CK_GFX950_SUPPORT +// Tuple format: +// clang-format off +using ABQuantEightWavesPaddedStrideTypes = ::testing::Types< + std::tuple +>; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantEightWavesPaddedStrideTypes); + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedPaddedBStrideTest) +{ + // 256-byte row alignment for FP8 -> 256 elements of leading-dim padding. + constexpr ck_tile::index_t k_batch = 1; + constexpr ck_tile::index_t stride_B_pad = 256; + this->run_test_with_validation(1024, 1024, 1024, k_batch, stride_B_pad); +} +#endif diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 8fbda4a3ce..e5731c5caa 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -1038,12 +1038,17 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBaseis_row_major(ALayout{})); + // stride_B_pad lets a test exercise a B tensor whose leading-dim stride is + // larger than the packed value (e.g. row-aligned padding). The host tensor, + // device buffer, and kernel args are all built with this padded stride so + // the kernel must honor the runtime stride to address B correctly. const ck_tile::index_t stride_B = - ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})); + ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})) + stride_B_pad; const ck_tile::index_t stride_C = ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{}));