mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
[rocm-libraries] ROCm/rocm-libraries#6611 (commit 5375c0f)
[CK_TILE] Preserve input strides in EightWaves async-load descriptor (#6611) `MakeAsyncLoadADramWindow` in `GemmPipelineAgBgCrCompAsyncEightWavesPolicy` was rebuilding the 6D view descriptor with `make_naive_tensor_descriptor_packed`, which synthesizes strides from lengths and assumes a dense layout. When the input view's leading-dim stride is larger than its inner length (non-packed memory layout), the resulting tile window stepped through memory at the wrong stride. Compose the unmerge transforms on top of the input view's existing descriptor instead, so the actual runtime strides are preserved and the correct `element_space_size` is inherited for bounds checking. ## Test Plan Added a unit test showing the problem. ## Test Result The new test fails before the fix and passes after it. ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
9d34174ac2
commit
cbfb3e242e
@@ -86,6 +86,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_eightwaves PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_eightwaves_padded_stride
|
||||
test_gemm_quant_abquant_eightwaves_padded_stride.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_eightwaves_padded_stride PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# ABQuant split-K tests
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode
|
||||
test_gemm_quant_abquant_splitk_decode.cpp
|
||||
@@ -281,6 +286,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
test_tile_gemm_quant_abquant_a4w4_padding
|
||||
test_tile_gemm_quant_abquant_a4w4_preshuffle
|
||||
test_tile_gemm_quant_abquant_eightwaves
|
||||
test_tile_gemm_quant_abquant_eightwaves_padded_stride
|
||||
# ABQuant split-K tests
|
||||
test_tile_gemm_quant_abquant_splitk_decode
|
||||
test_tile_gemm_quant_abquant_splitk_prefill
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// Regression test for the EightWaves ABQuant pipeline on a B tensor whose
|
||||
// leading-dim stride is larger than the packed value. The async B-load
|
||||
// descriptor in the EightWaves policy must be built from the input view's
|
||||
// real strides so that the kernel addresses B correctly when stride_B is
|
||||
// larger than the inner length (e.g. row-aligned weight padding).
|
||||
|
||||
#include "test_gemm_quant_common.hpp"
|
||||
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
#ifdef CK_GFX950_SUPPORT
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantEightWavesPaddedStrideTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWaves, GroupSize1D_128, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantEightWavesPaddedStrideTypes);
|
||||
|
||||
// Typed test: calls run_test_with_validation with M = N = K = 1024, k_batch = 1,
// and a non-zero stride_B_pad, so the run uses a B leading-dim stride that is
// larger than the default (packed) stride.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedPaddedBStrideTest)
|
||||
{
|
||||
// 256-byte row alignment for FP8 -> 256 elements of leading-dim padding.
// NOTE(review): the pad is an index_t added to the element stride, i.e. it is
// expressed in elements; the byte figure assumes 1-byte FP8 elements — confirm.
|
||||
constexpr ck_tile::index_t k_batch = 1;
|
||||
constexpr ck_tile::index_t stride_B_pad = 256;
|
||||
// Arguments are (M, N, K, k_batch, stride_B_pad).
this->run_test_with_validation(1024, 1024, 1024, k_batch, stride_B_pad);
|
||||
}
|
||||
#endif
|
||||
@@ -1038,12 +1038,17 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
void run_test_with_validation(ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t k_batch = 1)
|
||||
ck_tile::index_t k_batch = 1,
|
||||
ck_tile::index_t stride_B_pad = 0)
|
||||
{
|
||||
const ck_tile::index_t stride_A =
|
||||
ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{}));
|
||||
// stride_B_pad lets a test exercise a B tensor whose leading-dim stride is
|
||||
// larger than the packed value (e.g. row-aligned padding). The host tensor,
|
||||
// device buffer, and kernel args are all built with this padded stride so
|
||||
// the kernel must honor the runtime stride to address B correctly.
|
||||
const ck_tile::index_t stride_B =
|
||||
ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{}));
|
||||
ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})) + stride_B_pad;
|
||||
const ck_tile::index_t stride_C =
|
||||
ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{}));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user