mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#4294 (commit 6601702)
Cleanup and refactoring related to tile loading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Cleanup and refactoring done while implementing mixed precision for fp16/bf16 x fp8 Key changes: - Renamed load_interleaved_pk_type.hpp to load_and_convert_tile.hpp and refactored the API to use consistent naming conventions - Updated load_tile_transpose functions to use output parameters instead of return values for consistency - Removed unused variable declarations and simplified type deduction logic - Define load_tile_with_elementwise to use tuple types explicitly for clarity ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [x] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [X] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
0438ab1b79
commit
95dc496d30
@@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
|
||||
|
||||
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr);
|
||||
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
|
||||
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
@@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
@@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile( //
|
||||
|
||||
@@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
async_load_tile(q_lds_write_window, q_dram_window);
|
||||
async_load_tile(do_lds_write_window, do_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(0);
|
||||
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
@@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
s_waitcnt</*vmcnt=*/0>();
|
||||
k_reg_tensor = load_tile(k_lds_read_window);
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
|
||||
k_reg_tensor = load_tile(k_lds_read_window);
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
load_tile_transpose(kt_reg_tensor, kt_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
block_sync_lds();
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
@@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile( //
|
||||
|
||||
@@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
};
|
||||
|
||||
auto V_lds_load = [&](auto v_lds_read_idx) {
|
||||
kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
|
||||
load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx));
|
||||
};
|
||||
|
||||
decltype(m) m_old;
|
||||
|
||||
@@ -582,7 +582,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
// loop over along the [V]alue Sequence length
|
||||
move_tile_window(v_lds_read_window, {kK1, 0});
|
||||
v_tile = load_tile_transpose(v_lds_read_window);
|
||||
load_tile_transpose(v_tile, v_lds_read_window);
|
||||
});
|
||||
// move back to the origin
|
||||
move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});
|
||||
|
||||
Reference in New Issue
Block a user