diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 1cf3581859..23b5bec8d4 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -636,22 +636,41 @@ struct fmha_bwd_launcher if(traits_.is_group_mode) { + if(!seqstart_q_dev || !seqstart_k_dev) + throw std::runtime_error("fmha_bwd_launcher::prepare_workspace_async: " + "seqstart_q_dev and seqstart_k_dev are required in " + "group mode"); HIP_CHECK_ERROR(hipMemcpyAsync( pin_q, seqstart_q_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream)); HIP_CHECK_ERROR(hipMemcpyAsync( pin_k, seqstart_k_dev, seqstart_bytes, hipMemcpyDeviceToHost, stream)); } - auto* pack_closure = new std::function( + auto pack_closure = std::make_unique>( [=, fn = pack_workspace_host_]() { fn(pin_w, seqstart_q_pinned, seqstart_k_pinned); }); + // Callback runs on the HIP driver helper thread across a C ABI boundary; + // any exception escaping it would call std::terminate. HIP_CHECK_ERROR(hipLaunchHostFunc( stream, [](void* ud) { - auto* c = static_cast*>(ud); - (*c)(); - delete c; + std::unique_ptr> c{static_cast*>(ud)}; + try + { + (*c)(); + } + catch(const std::exception& e) + { + std::cerr << "fmha_bwd_launcher: pack_workspace_host threw: " << e.what() + << '\n'; + } + catch(...) + { + std::cerr << "fmha_bwd_launcher: pack_workspace_host threw unknown\n"; + } }, - pack_closure)); + pack_closure.get())); + // Ownership transferred to the callback only after a successful launch. + pack_closure.release(); HIP_CHECK_ERROR( hipMemcpyAsync(device_ws_ptr, pin_w, host_ws_size_, hipMemcpyHostToDevice, stream)); diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index 3a328b96d7..e4c8b63717 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -418,7 +418,7 @@ bwd_result fmha_bwd_run(mode_enum mode, }; ck_tile::gpu_timer prepare_ws_timer; - prepare_ws_timer.start(nullptr); + prepare_ws_timer.start(stream_config.stream_id_); launcher.prepare_workspace_async( ws_buf.GetDeviceBuffer(), (mode == mode_enum::group) ? static_cast(seqstart_q.GetDeviceBuffer()) @@ -427,7 +427,7 @@ bwd_result fmha_bwd_run(mode_enum mode, : nullptr, stream_config, pinned_host_alloc); - prepare_ws_timer.stop(nullptr); + prepare_ws_timer.stop(stream_config.stream_id_); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 8716c93bfb..8391a14832 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -323,7 +323,7 @@ struct FmhaBwdDQDKDVKernel std::forward(args)...); } template - CK_TILE_HOST static constexpr auto GetWorkspaceDeviceSizeUpperBound(Args&&... args) + CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(Args&&... args) { return WorkspaceManager::template GetWorkspaceDeviceSizeUpperBound< kUseQrQtrDorPipeline,