diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 2e907c2fa8..54becd3c0f 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -91,7 +91,7 @@ struct Default2DEpilogue CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, - void* = nullptr) + void* = nullptr) const { const auto storeOrUpdateTile = [&](const auto& o_tile) { // TODO: this is ugly diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 5adb64564d..dec5aa27b2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -707,18 +707,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); } - dk_epilogue(dk_dram_window, dk_acc); + dk_epilogue(dk_dram_window, dk_acc, nullptr); move_tile_window(dk_dram_window, {kN0, 0}); - dv_epilogue(dv_dram_window, dv_acc); + dv_epilogue(dv_dram_window, dv_acc, nullptr); move_tile_window(dv_dram_window, {kN0, 0}); } }; for(index_t i = 0; i < seqlen_kv_start; i += kN0) { - dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}); + dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}, nullptr); move_tile_window(dk_dram_window, {kN0, 0}); - dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}); + dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}, nullptr); move_tile_window(dv_dram_window, {kN0, 0}); } @@ -740,9 +740,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR const auto seqlen_kv_length = k_length.at(number<0>{}); for(; seqlen_kv_step < seqlen_kv_length; seqlen_kv_step += kN0) { - dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}); + dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}, nullptr); move_tile_window(dk_dram_window, {kN0, 0}); - dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}); + dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}, nullptr); move_tile_window(dv_dram_window, {kN0, 0}); } @@ -752,8 +752,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR dq_acc); else tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); - // static_assert(kIsDeterministic); - dq_epilogue(dq_dram_window, dq_acc); + dq_epilogue(dq_dram_window, dq_acc, nullptr); return; } };