[FIX] fix on fmha_bwd (#2784)

* fix on fmha_bwd

* Add 'const' to the Default2DEpilogue call operator

* Fix more calls to Default2DEpilogue

---------

Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
Co-authored-by: Yi DING <yi.ding@amd.com>
This commit is contained in:
Thomas Ning
2025-09-08 02:31:27 -04:00
committed by GitHub
parent e279e9420e
commit 42a43d1523
2 changed files with 8 additions and 9 deletions

View File

@@ -707,18 +707,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
dk_epilogue(dk_dram_window, dk_acc);
dk_epilogue(dk_dram_window, dk_acc, nullptr);
move_tile_window(dk_dram_window, {kN0, 0});
dv_epilogue(dv_dram_window, dv_acc);
dv_epilogue(dv_dram_window, dv_acc, nullptr);
move_tile_window(dv_dram_window, {kN0, 0});
}
};
for(index_t i = 0; i < seqlen_kv_start; i += kN0)
{
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}, nullptr);
move_tile_window(dk_dram_window, {kN0, 0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}, nullptr);
move_tile_window(dv_dram_window, {kN0, 0});
}
@@ -740,9 +740,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
const auto seqlen_kv_length = k_length.at(number<0>{});
for(; seqlen_kv_step < seqlen_kv_length; seqlen_kv_step += kN0)
{
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}, nullptr);
move_tile_window(dk_dram_window, {kN0, 0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}, nullptr);
move_tile_window(dv_dram_window, {kN0, 0});
}
@@ -752,8 +752,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
dq_acc);
else
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
// static_assert(kIsDeterministic);
dq_epilogue(dq_dram_window, dq_acc);
dq_epilogue(dq_dram_window, dq_acc, nullptr);
return;
}
};