diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp index eb22fbb5a2..4e05ecc59c 100644 --- a/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -39,7 +39,7 @@ CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) } else if constexpr(LoadTranspose) { - dst = load_tile_transpose(src); + load_tile_transpose(dst, src); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 9aeabaa8c2..16212c0d13 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR s_acc = gemm_0(q_reg_tensor, k_reg_tensor); dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); + load_tile_transpose(dot_reg_tensor, dot_lds_read_window); } if constexpr(is_epilogue) { @@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); + load_tile_transpose(qt_reg_tensor, qt_lds_read_window); // STAGE 3, P^T@OGrad^T Gemm1 auto pt_reg_tensor = make_static_distributed_tensor( @@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR } if constexpr(is_epilogue) { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } if constexpr(is_main_body) @@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR static_for<0, k4_loops, 1>{}([&](auto i_k4) { if constexpr(i_k4 < k4_loops - 1) { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } auto kt_reg_tensor_slice = get_slice_tile( // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 3d21928ced..37b4ae41a3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR async_load_tile(q_lds_write_window, q_dram_window); async_load_tile(do_lds_write_window, do_dram_window); __builtin_amdgcn_s_waitcnt(0); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); - q_reg_tensor = load_tile(q_lds_read_window); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); - do_reg_tensor = load_tile(do_lds_read_window); + load_tile_transpose(qt_reg_tensor, qt_lds_read_window); + q_reg_tensor = load_tile(q_lds_read_window); + load_tile_transpose(dot_reg_tensor, dot_lds_read_window); + do_reg_tensor = load_tile(do_lds_read_window); lse_block_tile = load_tile(lse_dram_window); d_block_tile = load_tile(d_dram_window); @@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR async_load_tile(v_lds_write_window, v_dram_window); move_tile_window(v_dram_window, {kN0, 0}); s_waitcnt(); - k_reg_tensor = load_tile(k_lds_read_window); - v_reg_tensor = load_tile(v_lds_read_window); - kt_reg_tensor = load_tile_transpose(kt_lds_read_window); + k_reg_tensor = load_tile(k_lds_read_window); + v_reg_tensor = load_tile(v_lds_read_window); + load_tile_transpose(kt_reg_tensor, kt_lds_read_window); } if constexpr(is_epilogue) { @@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR block_sync_lds(); if constexpr(is_epilogue) { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } if constexpr(is_main_body) @@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR static_for<0, k4_loops, 1>{}([&](auto i_k4) { if constexpr(i_k4 < k4_loops - 1) { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } auto kt_reg_tensor_slice = get_slice_tile( // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index c25f57632f..4cca604ff1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline }; auto V_lds_load = [&](auto v_lds_read_idx) { - kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx)); + load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx)); }; decltype(m) m_old; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 26662dafeb..3e958ea531 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -567,7 +567,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload // loop over along the [V]alue Sequence length move_tile_window(v_lds_read_window, {kK1, 0}); - v_tile = load_tile_transpose(v_lds_read_window); + load_tile_transpose(v_tile, v_lds_read_window); }); // move back to the origin move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index ffac45bed4..358101d1db 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -107,7 +107,7 @@ struct GemmPipelineAgBgCrImplBase bool_constant = {}) const { if constexpr(LoadTranspose) - dst_block_tile = load_tile_transpose(lds_tile_window); + load_tile_transpose(dst_block_tile, lds_tile_window); else load_tile(dst_block_tile, lds_tile_window); }