diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index c747786e..ab477ebf 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -3810,7 +3810,22 @@ void server_context::speculative_decoding_accept() {
         apply_server_biases(slot);
 
         // the accepted tokens from the speculation
-        const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
+        std::vector<llama_token> ids;
+        try {
+            ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
+        } catch (const std::exception & e) {
+            LOG_ERROR("speculative sampling failed, releasing slot", {
+                {"id_slot", slot.id},
+                {"id_task", slot.id_task},
+                {"error", e.what()},
+            });
+            send_error(slot, std::string("sampling error: ") + e.what(), ERROR_TYPE_SERVER);
+            slot.release();
+            slot.i_batch = -1;
+            slot.i_batch_dft.clear();
+            slot.drafted.clear();
+            continue;
+        }
 
         int32_t mtp_n_past_base = 0;
         std::vector<float> mtp_hidden_state_pre;
@@ -4320,9 +4335,21 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
         apply_server_biases(slot);
 
-        const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
-
-        common_sampler_accept(slot.ctx_sampling, ctx, id, true);
+        llama_token id;
+        try {
+            id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
+            common_sampler_accept(slot.ctx_sampling, ctx, id, true);
+        } catch (const std::exception & e) {
+            LOG_ERROR("sampling failed, releasing slot", {
+                {"id_slot", slot.id},
+                {"id_task", slot.id_task},
+                {"error", e.what()},
+            });
+            send_error(slot, std::string("sampling error: ") + e.what(), ERROR_TYPE_SERVER);
+            slot.release();
+            slot.i_batch = -1;
+            continue;
+        }
 
         slot.n_decoded += 1;
 
         const int64_t t_current = ggml_time_us();
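
Below is a minimal standalone sketch of the error-handling pattern the diff introduces: a sampling call that may throw is wrapped in try/catch, and on failure the error is logged, reported, and the slot is released and its batch bookkeeping reset so the loop can continue serving other slots. The names here (Slot, fake_sample, log_error) are hypothetical stand-ins, not the server's real types or helpers; only the control flow mirrors the change.

    // sketch only: hypothetical stand-ins for the server's slot, sampler, and logger
    #include <cstdio>
    #include <stdexcept>
    #include <string>
    #include <vector>

    struct Slot {
        int id      = 0;
        int i_batch = 0;
        std::vector<int> drafted;
        void release() { std::printf("slot %d released\n", id); }
    };

    // hypothetical sampler that may throw, standing in for common_sampler_sample()
    static int fake_sample(bool fail) {
        if (fail) {
            throw std::runtime_error("logits contained NaN");
        }
        return 42;
    }

    static void log_error(const std::string & msg, int id_slot) {
        std::fprintf(stderr, "ERR: %s (id_slot=%d)\n", msg.c_str(), id_slot);
    }

    int main() {
        std::vector<Slot> slots(2);
        slots[1].id = 1;

        for (auto & slot : slots) {
            int id = -1;
            try {
                // the sampler may throw; on success the token would be accepted as usual
                id = fake_sample(/*fail=*/slot.id == 1);
            } catch (const std::exception & e) {
                // mirror the diff: log, report the error, then release the slot and
                // reset its batch state so the outer loop keeps running
                log_error(std::string("sampling error: ") + e.what(), slot.id);
                slot.release();
                slot.i_batch = -1;
                slot.drafted.clear();
                continue;
            }
            std::printf("slot %d sampled token %d\n", slot.id, id);
        }
        return 0;
    }

The key design point the sketch illustrates is that a failure in one slot's sampling no longer propagates out of the batch-processing loop: the slot is cleaned up in place and iteration continues for the remaining slots.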