[rocm-libraries] ROCm/rocm-libraries#4294 (commit 6601702)

Cleanup and refactoring related to tile loading
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Cleanup and refactoring done while implementing mixed precision for
fp16/bf16 x fp8

Key changes:

- Renamed load_interleaved_pk_type.hpp to load_and_convert_tile.hpp and
refactored the API to use consistent naming conventions
- Updated load_tile_transpose functions to use output parameters instead
of return values for consistency
- Removed unused variable declarations and simplified type deduction
logic
- Define load_tile_with_elementwise to use tuple types explicitly for
clarity

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [x] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [X] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
SamiAario-AMD
2026-03-02 12:21:44 +00:00
committed by assistant-librarian[bot]
parent 0438ab1b79
commit 95dc496d30
47 changed files with 190 additions and 182 deletions

View File

@@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
}
if constexpr(is_epilogue)
{
@@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr);
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
// STAGE 3, P^T@OGrad^T Gemm1
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
@@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
}
if constexpr(is_epilogue)
{
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
if constexpr(is_main_body)
@@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
auto kt_reg_tensor_slice = get_slice_tile( //

View File

@@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
async_load_tile(q_lds_write_window, q_dram_window);
async_load_tile(do_lds_write_window, do_dram_window);
__builtin_amdgcn_s_waitcnt(0);
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
q_reg_tensor = load_tile(q_lds_read_window);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
do_reg_tensor = load_tile(do_lds_read_window);
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
q_reg_tensor = load_tile(q_lds_read_window);
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
do_reg_tensor = load_tile(do_lds_read_window);
lse_block_tile = load_tile(lse_dram_window);
d_block_tile = load_tile(d_dram_window);
@@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
async_load_tile(v_lds_write_window, v_dram_window);
move_tile_window(v_dram_window, {kN0, 0});
s_waitcnt</*vmcnt=*/0>();
k_reg_tensor = load_tile(k_lds_read_window);
v_reg_tensor = load_tile(v_lds_read_window);
kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
k_reg_tensor = load_tile(k_lds_read_window);
v_reg_tensor = load_tile(v_lds_read_window);
load_tile_transpose(kt_reg_tensor, kt_lds_read_window);
}
if constexpr(is_epilogue)
{
@@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
block_sync_lds();
if constexpr(is_epilogue)
{
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
if constexpr(is_main_body)
@@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
auto kt_reg_tensor_slice = get_slice_tile( //

View File

@@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline
};
auto V_lds_load = [&](auto v_lds_read_idx) {
kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx));
};
decltype(m) m_old;

View File

@@ -582,7 +582,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
// loop over along the [V]alue Sequence length
move_tile_window(v_lds_read_window, {kK1, 0});
v_tile = load_tile_transpose(v_lds_read_window);
load_tile_transpose(v_tile, v_lds_read_window);
});
// move back to the origin
move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});