When possible, use the overload of load_tile_transpose that does not require assignment

This commit is contained in:
Sami Aario
2026-01-02 15:43:35 +00:00
parent 321611081f
commit ca17ac3358
6 changed files with 17 additions and 17 deletions

View File

@@ -39,7 +39,7 @@ CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src)
}
else if constexpr(LoadTranspose)
{
dst = load_tile_transpose(src);
load_tile_transpose(dst, src);
}
else
{

View File

@@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
}
if constexpr(is_epilogue)
{
@@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr);
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
// STAGE 3, P^T@OGrad^T Gemm1
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
@@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
}
if constexpr(is_epilogue)
{
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
if constexpr(is_main_body)
@@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
auto kt_reg_tensor_slice = get_slice_tile( //

View File

@@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
async_load_tile(q_lds_write_window, q_dram_window);
async_load_tile(do_lds_write_window, do_dram_window);
__builtin_amdgcn_s_waitcnt(0);
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
q_reg_tensor = load_tile(q_lds_read_window);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
do_reg_tensor = load_tile(do_lds_read_window);
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
q_reg_tensor = load_tile(q_lds_read_window);
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
do_reg_tensor = load_tile(do_lds_read_window);
lse_block_tile = load_tile(lse_dram_window);
d_block_tile = load_tile(d_dram_window);
@@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
async_load_tile(v_lds_write_window, v_dram_window);
move_tile_window(v_dram_window, {kN0, 0});
s_waitcnt</*vmcnt=*/0>();
k_reg_tensor = load_tile(k_lds_read_window);
v_reg_tensor = load_tile(v_lds_read_window);
kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
k_reg_tensor = load_tile(k_lds_read_window);
v_reg_tensor = load_tile(v_lds_read_window);
load_tile_transpose(kt_reg_tensor, kt_lds_read_window);
}
if constexpr(is_epilogue)
{
@@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
block_sync_lds();
if constexpr(is_epilogue)
{
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
if constexpr(is_main_body)
@@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
auto kt_reg_tensor_slice = get_slice_tile( //

View File

@@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline
};
auto V_lds_load = [&](auto v_lds_read_idx) {
kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx));
};
decltype(m) m_old;

View File

@@ -567,7 +567,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
// loop over along the [V]alue Sequence length
move_tile_window(v_lds_read_window, {kK1, 0});
v_tile = load_tile_transpose(v_lds_read_window);
load_tile_transpose(v_tile, v_lds_read_window);
});
// move back to the origin
move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});

View File

@@ -107,7 +107,7 @@ struct GemmPipelineAgBgCrImplBase
bool_constant<LoadTranspose> = {}) const
{
if constexpr(LoadTranspose)
dst_block_tile = load_tile_transpose(lds_tile_window);
load_tile_transpose(dst_block_tile, lds_tile_window);
else
load_tile(dst_block_tile, lds_tile_window);
}