mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[rocm-libraries] ROCm/rocm-libraries#6611 (commit 5375c0f)
[CK_TILE] Preserve input strides in EightWaves async-load descriptor (#6611) `MakeAsyncLoadADramWindow` in `GemmPipelineAgBgCrCompAsyncEightWavesPolicy` was rebuilding the 6D view descriptor with `make_naive_tensor_descriptor_packed`, which synthesizes strides from lengths and assumes a dense layout. When the input view's leading-dim stride is larger than its inner length (non-packed memory layout), the resulting tile window stepped through memory at the wrong stride. Compose the unmerge transforms on top of the input view's existing descriptor instead, so the actual runtime strides are preserved and the correct `element_space_size` is inherited for bounds checking. ## Test Plan Added an unit test showing the problem. ## Test Result The new test fails before fixes and passes after. ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
9d34174ac2
commit
cbfb3e242e
@@ -176,10 +176,15 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
const index_t M0 = integer_divide_ceil(rows, M1);
|
||||
const auto row_lens = make_tuple(M0, number<M1>{});
|
||||
|
||||
const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
|
||||
const auto desc_0 = decltype(d0)( // set correct size (without padding)
|
||||
d0.get_transforms(),
|
||||
tensor_view_tmp.get_tensor_descriptor().get_element_space_size());
|
||||
// Build the 6D view by composing unmerge transforms on top of the
|
||||
// input view's existing descriptor. This preserves the input's actual
|
||||
// strides (so a non-packed leading-dim stride is honored) and inherits
|
||||
// its element_space_size for bounds checking.
|
||||
const auto desc_0 = transform_tensor_descriptor(
|
||||
tensor_view_tmp.get_tensor_descriptor(),
|
||||
make_tuple(make_unmerge_transform(row_lens), make_unmerge_transform(col_lens)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{}));
|
||||
const auto desc_1 = transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_pass_through_transform(M0),
|
||||
|
||||
Reference in New Issue
Block a user