mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-22 07:19:08 +00:00
[fix](kt-kernel): fix write_buffer to run via do_numa_job (#1699)
This commit is contained in:
@@ -1304,16 +1304,19 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE_Common<AMX_K2_MOE_TP<K>> {
|
||||
throw std::runtime_error("Pointer arrays size must match gpu_tp_count");
|
||||
}
|
||||
|
||||
auto& config = this->config;
|
||||
auto pool = config.pool;
|
||||
// Each TP part writes to its corresponding buffer
|
||||
for (int tp_idx = 0; tp_idx < this->tp_count; tp_idx++) {
|
||||
pool->dispense_backend()->do_numa_job([this, pool, gpu_tp_count, gpu_experts_num,
|
||||
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs](int numa_id) {
|
||||
// Note: w13 combines gate and up projections
|
||||
// Split w13 pointers for gate and up
|
||||
this->tps[tp_idx]->write_weights_to_buffer(
|
||||
this->tps[numa_id]->write_weights_to_buffer(
|
||||
gpu_tp_count, this->tp_count,
|
||||
gpu_experts_num, this->config,
|
||||
w13_weight_ptrs, w13_scale_ptrs, //gate + up use w13
|
||||
w2_weight_ptrs, w2_scale_ptrs); // down uses w2
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void merge_results(int qlen, void* output, bool incremental) {
|
||||
|
||||
Reference in New Issue
Block a user