[rocm-libraries] ROCm/rocm-libraries#6611 (commit 5375c0f)

[CK_TILE] Preserve input strides in EightWaves async-load
 descriptor (#6611)

`MakeAsyncLoadADramWindow` in
`GemmPipelineAgBgCrCompAsyncEightWavesPolicy` was rebuilding the 6D view
descriptor with `make_naive_tensor_descriptor_packed`, which synthesizes
strides from lengths and assumes a dense layout. When the input view's
leading-dim stride is larger than its inner length (non-packed memory
layout), the resulting tile window stepped through memory at the wrong
stride.

Compose the unmerge transforms on top of the input view's existing
descriptor instead, so the actual runtime strides are preserved and the
correct `element_space_size` is inherited for bounds checking.

## Test Plan

Added an unit test showing the problem.

## Test Result

The new test fails before fixes and passes after.

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Sami Remes
2026-04-22 10:52:59 +00:00
committed by assistant-librarian[bot]
parent 9d34174ac2
commit cbfb3e242e
4 changed files with 53 additions and 6 deletions

View File

@@ -176,10 +176,15 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
const index_t M0 = integer_divide_ceil(rows, M1);
const auto row_lens = make_tuple(M0, number<M1>{});
const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto desc_0 = decltype(d0)( // set correct size (without padding)
d0.get_transforms(),
tensor_view_tmp.get_tensor_descriptor().get_element_space_size());
// Build the 6D view by composing unmerge transforms on top of the
// input view's existing descriptor. This preserves the input's actual
// strides (so a non-packed leading-dim stride is honored) and inherits
// its element_space_size for bounds checking.
const auto desc_0 = transform_tensor_descriptor(
tensor_view_tmp.get_tensor_descriptor(),
make_tuple(make_unmerge_transform(row_lens), make_unmerge_transform(col_lens)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{}));
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),