mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
When possible, use the overload of load_tile_transpose that does not require assignment
This commit is contained in:
@@ -39,7 +39,7 @@ CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src)
|
||||
}
|
||||
else if constexpr(LoadTranspose)
|
||||
{
|
||||
dst = load_tile_transpose(src);
|
||||
load_tile_transpose(dst, src);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
|
||||
|
||||
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr);
|
||||
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
|
||||
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
@@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
@@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile( //
|
||||
|
||||
@@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
async_load_tile(q_lds_write_window, q_dram_window);
|
||||
async_load_tile(do_lds_write_window, do_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(0);
|
||||
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
@@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
s_waitcnt</*vmcnt=*/0>();
|
||||
k_reg_tensor = load_tile(k_lds_read_window);
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
|
||||
k_reg_tensor = load_tile(k_lds_read_window);
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
load_tile_transpose(kt_reg_tensor, kt_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
block_sync_lds();
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
@@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile( //
|
||||
|
||||
@@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
};
|
||||
|
||||
auto V_lds_load = [&](auto v_lds_read_idx) {
|
||||
kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
|
||||
load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx));
|
||||
};
|
||||
|
||||
decltype(m) m_old;
|
||||
|
||||
@@ -567,7 +567,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
// loop over along the [V]alue Sequence length
|
||||
move_tile_window(v_lds_read_window, {kK1, 0});
|
||||
v_tile = load_tile_transpose(v_lds_read_window);
|
||||
load_tile_transpose(v_tile, v_lds_read_window);
|
||||
});
|
||||
// move back to the origin
|
||||
move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});
|
||||
|
||||
@@ -107,7 +107,7 @@ struct GemmPipelineAgBgCrImplBase
|
||||
bool_constant<LoadTranspose> = {}) const
|
||||
{
|
||||
if constexpr(LoadTranspose)
|
||||
dst_block_tile = load_tile_transpose(lds_tile_window);
|
||||
load_tile_transpose(dst_block_tile, lds_tile_window);
|
||||
else
|
||||
load_tile(dst_block_tile, lds_tile_window);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user