mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-20 21:24:08 +00:00
cuda: reduce qwen3next moe/ssm sync overhead and refresh eval
This commit is contained in:
@@ -315,3 +315,70 @@ CUDA_VISIBLE_DEVICES=0,1 /ik_llama.cpp/build-cuda13-fresh/bin/llama-sweep-bench
|
||||
- Remaining likely bottlenecks for 16GB PP:
|
||||
- MoE routing still limited by per-expert launches/host-side per-expert loop in `mul_mat_id`.
|
||||
- Scheduler split / backend-crossing overhead remains visible at this config.
|
||||
|
||||
## 2026-02-06 Follow-up Hotspot Pass (this session)
|
||||
|
||||
### Additional code changes
|
||||
|
||||
1. `ggml/src/ggml-cuda.cu`
|
||||
- Removed an unused `ids` device->host copy + stream sync in `ggml_cuda_moe_up_gate_unary` fallback path.
|
||||
- Reduced row-mapping host transfer volume by deriving `moe_counts` from host-side prefix bounds (`cum_moe_counts`) instead of copying both arrays from device.
|
||||
- Added `build_active_experts(...)` and switched per-expert loops to iterate only active experts.
|
||||
2. `ggml/src/ggml-cuda/ssm-conv.cu`
|
||||
- Removed host-side `cudaMemcpyAsync(...D2H...) + cudaStreamSynchronize` for multi-seq fast-path eligibility.
|
||||
- Made fast/fallback dispatch fully async by gating both kernels with a device-side `fast_path_ok` flag.
|
||||
3. `ggml/src/ggml-backend.cpp`
|
||||
- Reduced unnecessary split churn when a weight tensor is on another backend but the current backend can consume that buffer type directly.
|
||||
- Increased `GGML_SCHED_MAX_SPLITS` from `2048` to `4096` for large-graph headroom.
|
||||
4. `src/llama.cpp`
|
||||
- Added a Qwen3Next-specific default split guard for heterogeneous dual-GPU layer mode: clamp to at least `75/25` on 2-GPU auto-split when GPU0 has more free memory.
|
||||
5. `scripts/qwen3next-eval.sh`
|
||||
- Fixed CLI compatibility (`mainline: llama-completion`, `ik: llama-cli` completion path).
|
||||
- Made evaluation resilient to missing binaries (`gpu_sweep_mainline` is skipped if unavailable).
|
||||
- Fixed complexity-token regex.
|
||||
- Switched PPL corpus generation to a stable deterministic pattern to reduce chunk-level variance.
|
||||
|
||||
### Validation rerun
|
||||
|
||||
Run artifact: `/tmp/qwen3next-eval/20260206_064339`
|
||||
|
||||
- CPU PPL parity:
|
||||
- chunks=1: mainline `1.0009`, ik `1.0009`, delta `0.000000`
|
||||
- chunks=2: mainline `1.0005`, ik `1.0005`, delta `0.000000`
|
||||
- CUDA sanity parity:
|
||||
- `gpu_ppl_chunks1_mainline`: `OK`
|
||||
- `gpu_ppl_chunks1_ik`: `OK`
|
||||
- Generation smoke:
|
||||
- both mainline and ik contain Fibonacci token(s)
|
||||
- mainline contains complexity token(s), ik did not in this sample output
|
||||
- Notes:
|
||||
- `gpu_sweep_mainline` skipped in this environment because `/home/yurko/Code/llama.cpp/build/bin/llama-sweep-bench` is not present.
|
||||
- `gpu_sweep_ik` (`c=2048`, `n=32`) in this run peaked at approximately `maxPP=137.02`, `maxTG=24.81`.
|
||||
|
||||
### Quick matrix (exact required configs)
|
||||
|
||||
Run artifact: `/tmp/qwen3next-matrix/20260206_063957`
|
||||
|
||||
| Profile | Baseline maxPP | Baseline maxTG | New maxPP | New maxTG | Delta maxPP | Delta maxTG |
|
||||
| --- | ---: | ---: | ---: | ---: | ---: | ---: |
|
||||
| 16GB a) `CUDA_VISIBLE_DEVICES=0 --cpu-moe` | 129.83 | 26.45 | 115.56 | 25.74 | -14.27 | -0.71 |
|
||||
| 16GB b) `CUDA_VISIBLE_DEVICES=0 --cpu-moe -no-ooae` | n/a | n/a | 136.21 | 26.00 | n/a | n/a |
|
||||
| 28GB a) `CUDA_VISIBLE_DEVICES=0,1 --cpu-moe --tensor-split 0.85,0.15` | 127.66 | 22.95 | 129.70 | 22.72 | +2.04 | -0.23 |
|
||||
| 28GB b) `CUDA_VISIBLE_DEVICES=0,1 --cpu-moe` | n/a | n/a | 117.54 | 22.99 | n/a | n/a |
|
||||
|
||||
### Variance note for single-GPU default (`--cpu-moe`)
|
||||
|
||||
Repeated measurements show substantial run-to-run variance in this environment:
|
||||
|
||||
Run artifact: `/tmp/qwen3next-repeat-20260206_064133`
|
||||
|
||||
- `single_cpu_moe` maxPP/maxTG:
|
||||
- run1: `113.84 / 25.86`
|
||||
- run2: `135.29 / 26.88`
|
||||
- run3: `113.95 / 23.54`
|
||||
- `single_cpu_moe_no_ooae` maxPP/maxTG:
|
||||
- run1: `135.33 / 26.49`
|
||||
- run2: `133.64 / 24.92`
|
||||
- run3: `126.33 / 23.42`
|
||||
|
||||
Interpretation: in this setup, `-no-ooae` is currently more stable and generally faster for PP; default OOAE shows large variance and occasional severe PP drops.
|
||||
|
||||
@@ -1103,7 +1103,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
|
||||
#endif
|
||||
|
||||
#ifndef GGML_SCHED_MAX_SPLITS
|
||||
#define GGML_SCHED_MAX_SPLITS 2048
|
||||
#define GGML_SCHED_MAX_SPLITS 4096
|
||||
#endif
|
||||
|
||||
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
|
||||
@@ -1731,7 +1731,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
|
||||
// by starting a new split, the memory of the previously offloaded weights can be reused
|
||||
if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
|
||||
int src_backend_id = tensor_backend_id(src);
|
||||
if (src_backend_id != cur_backend_id) {
|
||||
bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id);
|
||||
if (src_backend_id != cur_backend_id && !supported) {
|
||||
need_new_split = true;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -2386,6 +2386,19 @@ static __global__ void k_moe_row_scatter(
|
||||
}
|
||||
}
|
||||
|
||||
static inline void build_active_experts(
|
||||
const std::vector<int> & moe_counts,
|
||||
std::vector<int32_t> & active_experts) {
|
||||
active_experts.clear();
|
||||
active_experts.reserve(moe_counts.size());
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) moe_counts.size(); ++i) {
|
||||
if (moe_counts[i] > 0) {
|
||||
active_experts.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
|
||||
const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
|
||||
ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
|
||||
@@ -2436,11 +2449,14 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
int32_t has_invalid_ids = 0;
|
||||
CUDA_CHECK(cudaMemcpyAsync(moe_counts.data(), dev_moe_counts.get(), n_as*sizeof(int), cudaMemcpyDeviceToHost, stream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(cum_moe_counts.data(), dev_cum_moe_counts.get(), (n_as + 1)*sizeof(int), cudaMemcpyDeviceToHost, stream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(&has_invalid_ids, dev_has_invalid_ids.get(), sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) n_as; ++i) {
|
||||
moe_counts[i] = cum_moe_counts[i + 1] - cum_moe_counts[i];
|
||||
}
|
||||
|
||||
return has_invalid_ids != 0;
|
||||
}
|
||||
|
||||
@@ -2558,6 +2574,8 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
if (is_ser) {
|
||||
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
|
||||
}
|
||||
std::vector<int32_t> active_experts;
|
||||
build_active_experts(moe_counts, active_experts);
|
||||
|
||||
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
|
||||
ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
|
||||
@@ -2565,14 +2583,10 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
src1_row.data = src1_contiguous.get();
|
||||
dst_row.data = dst_contiguous.get();
|
||||
|
||||
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
||||
for (int32_t i02 : active_experts) {
|
||||
|
||||
int64_t num_src1_rows = moe_counts[i02];
|
||||
|
||||
if (num_src1_rows == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t mapping_offset = cum_moe_counts[i02];
|
||||
|
||||
{
|
||||
@@ -2895,11 +2909,6 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
return i;
|
||||
}
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
const char * ids_dev = (const char *) ids->data;
|
||||
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
ggml_tensor src0_1_row = *src0_1;
|
||||
ggml_tensor src0_2_row; if (src0_2) src0_2_row = *src0_2;
|
||||
ggml_tensor src1_row = *src1;
|
||||
@@ -2986,11 +2995,12 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> active_experts;
|
||||
build_active_experts(moe_counts, active_experts);
|
||||
|
||||
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
||||
for (int32_t i02 : active_experts) {
|
||||
int64_t num_src1_rows = moe_counts[i02];
|
||||
|
||||
if (num_src1_rows == 0) continue;
|
||||
size_t mapping_offset = cum_moe_counts[i02];
|
||||
|
||||
if (use_quantized_src1) {
|
||||
|
||||
@@ -205,12 +205,17 @@ static __global__ void ssm_conv_multi_seq_unique_f32_kernel(
|
||||
const float * src1,
|
||||
const float * src2,
|
||||
const int32_t * seq_ids,
|
||||
const int32_t * fast_path_ok,
|
||||
float * dst_x,
|
||||
float * dst_state,
|
||||
int nc,
|
||||
int nr,
|
||||
int n_t,
|
||||
int src1_nb1) {
|
||||
if (fast_path_ok != nullptr && fast_path_ok[0] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int t = blockIdx.y;
|
||||
|
||||
@@ -241,11 +246,16 @@ static __global__ void ssm_conv_multi_seq_unique_f32_kernel_nc4(
|
||||
const float * src1,
|
||||
const float * src2,
|
||||
const int32_t * seq_ids,
|
||||
const int32_t * fast_path_ok,
|
||||
float * dst_x,
|
||||
float * dst_state,
|
||||
int nr,
|
||||
int n_t,
|
||||
int src1_nb1) {
|
||||
if (fast_path_ok != nullptr && fast_path_ok[0] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int t = blockIdx.y;
|
||||
|
||||
@@ -276,6 +286,7 @@ static __global__ void ssm_conv_f32_kernel(
|
||||
const float * src1,
|
||||
const float * src2,
|
||||
const int32_t * src3,
|
||||
const int32_t * fast_path_ok,
|
||||
float * dst_x,
|
||||
float * dst_state,
|
||||
int nc,
|
||||
@@ -284,6 +295,10 @@ static __global__ void ssm_conv_f32_kernel(
|
||||
int n_kv,
|
||||
int src1_nb1,
|
||||
int src3_nb1) {
|
||||
if (fast_path_ok != nullptr && fast_path_ok[0] != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (row >= nr) {
|
||||
return;
|
||||
@@ -338,6 +353,7 @@ static __global__ void ssm_conv_f32_kernel_nc4(
|
||||
const float * src1,
|
||||
const float * src2,
|
||||
const int32_t * src3,
|
||||
const int32_t * fast_path_ok,
|
||||
float * dst_x,
|
||||
float * dst_state,
|
||||
int nr,
|
||||
@@ -345,6 +361,10 @@ static __global__ void ssm_conv_f32_kernel_nc4(
|
||||
int n_kv,
|
||||
int src1_nb1,
|
||||
int src3_nb1) {
|
||||
if (fast_path_ok != nullptr && fast_path_ok[0] != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (row >= nr) {
|
||||
return;
|
||||
@@ -441,6 +461,8 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
const dim3 block_dims(CUDA_SSM_CONV_BLOCK_SIZE, 1, 1);
|
||||
const dim3 row_grid((nr + CUDA_SSM_CONV_BLOCK_SIZE - 1) / CUDA_SSM_CONV_BLOCK_SIZE, 1, 1);
|
||||
ggml_cuda_pool_alloc<int32_t> fast_path_ok_d(ctx.pool());
|
||||
const int32_t * multi_seq_fast_path_ok = nullptr;
|
||||
|
||||
// Fast path for single-sequence recurrent updates (Qwen3Next prompt/decode path).
|
||||
// In this case, outputs are independent given the initial conv state, so we parallelize over token blocks.
|
||||
@@ -499,8 +521,8 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
// one token per unique sequence, no copy-to-multiple-sequences routing.
|
||||
ggml_cuda_pool_alloc<int32_t> seq_ids(ctx.pool(), n_t);
|
||||
ggml_cuda_pool_alloc<int32_t> seq_seen(ctx.pool(), n_kv);
|
||||
ggml_cuda_pool_alloc<int32_t> fast_path_ok_d(ctx.pool(), 1);
|
||||
int32_t fast_path_ok = 1;
|
||||
fast_path_ok_d.alloc(1);
|
||||
|
||||
CUDA_CHECK(cudaMemsetAsync(seq_seen.get(), 0, n_kv * sizeof(int32_t), ctx.stream()));
|
||||
CUDA_CHECK(cudaMemcpyAsync(fast_path_ok_d.get(), &fast_path_ok, sizeof(int32_t), cudaMemcpyHostToDevice, ctx.stream()));
|
||||
@@ -515,35 +537,32 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
n_t,
|
||||
n_kv,
|
||||
src3->nb[1] / sizeof(int32_t));
|
||||
|
||||
CUDA_CHECK(cudaMemcpyAsync(&fast_path_ok, fast_path_ok_d.get(), sizeof(int32_t), cudaMemcpyDeviceToHost, ctx.stream()));
|
||||
CUDA_CHECK(cudaStreamSynchronize(ctx.stream()));
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
multi_seq_fast_path_ok = fast_path_ok_d.get();
|
||||
|
||||
if (fast_path_ok) {
|
||||
const dim3 token_grid(row_grid.x, n_t, 1);
|
||||
if (nc == 4) {
|
||||
ssm_conv_multi_seq_unique_f32_kernel_nc4<<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
seq_ids.get(),
|
||||
dst_x,
|
||||
dst_state,
|
||||
nr, n_t,
|
||||
src1->nb[1] / sizeof(float));
|
||||
} else {
|
||||
ssm_conv_multi_seq_unique_f32_kernel<<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
seq_ids.get(),
|
||||
dst_x,
|
||||
dst_state,
|
||||
nc, nr, n_t,
|
||||
src1->nb[1] / sizeof(float));
|
||||
}
|
||||
return;
|
||||
const dim3 token_grid(row_grid.x, n_t, 1);
|
||||
if (nc == 4) {
|
||||
ssm_conv_multi_seq_unique_f32_kernel_nc4<<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
seq_ids.get(),
|
||||
multi_seq_fast_path_ok,
|
||||
dst_x,
|
||||
dst_state,
|
||||
nr, n_t,
|
||||
src1->nb[1] / sizeof(float));
|
||||
} else {
|
||||
ssm_conv_multi_seq_unique_f32_kernel<<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
seq_ids.get(),
|
||||
multi_seq_fast_path_ok,
|
||||
dst_x,
|
||||
dst_state,
|
||||
nc, nr, n_t,
|
||||
src1->nb[1] / sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -554,6 +573,7 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
(const int32_t *) src3->data,
|
||||
multi_seq_fast_path_ok,
|
||||
dst_x,
|
||||
dst_state,
|
||||
nr, n_t, n_kv,
|
||||
@@ -565,6 +585,7 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
(const int32_t *) src3->data,
|
||||
nullptr,
|
||||
dst_x,
|
||||
dst_state,
|
||||
nr, n_t, n_kv,
|
||||
@@ -577,6 +598,7 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
(const int32_t *) src3->data,
|
||||
multi_seq_fast_path_ok,
|
||||
dst_x,
|
||||
dst_state,
|
||||
nc, nr, n_t, n_kv,
|
||||
|
||||
468
scripts/qwen3next-eval.sh
Executable file
468
scripts/qwen3next-eval.sh
Executable file
@@ -0,0 +1,468 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
IMAGE="${IMAGE:-iktest-dev:latest}"
|
||||
MAIN_REPO="${MAIN_REPO:-/home/yurko/Code/llama.cpp}"
|
||||
IK_REPO="${IK_REPO:-/home/yurko/Code/ik_llama.cpp}"
|
||||
MODEL_HOST="${MODEL_HOST:-/home/yurko/.cache/llama.cpp/qwen3-next-coder.gguf}"
|
||||
OUT_ROOT="${OUT_ROOT:-/tmp/qwen3next-eval}"
|
||||
WITH_GPU=0
|
||||
GPU_DEVICE="${GPU_DEVICE:-0}"
|
||||
SWEEP_CTX="${SWEEP_CTX:-2048}"
|
||||
SWEEP_N="${SWEEP_N:-32}"
|
||||
|
||||
usage() {
|
||||
cat <<'USAGE'
|
||||
Usage:
|
||||
scripts/qwen3next-eval.sh [options]
|
||||
|
||||
Options:
|
||||
--with-gpu Enable GPU checks in addition to CPU checks.
|
||||
--gpu-device ID CUDA device id to use for GPU sanity checks (default: 0).
|
||||
--image IMAGE Docker image to run checks in (default: iktest-dev:latest).
|
||||
--main-repo PATH Mainline repo path (default: /home/yurko/Code/llama.cpp).
|
||||
--ik-repo PATH ik repo path (default: /home/yurko/Code/ik_llama.cpp).
|
||||
--model PATH Host path to model GGUF file.
|
||||
--out-root PATH Output root directory (default: /tmp/qwen3next-eval).
|
||||
--sweep-ctx N Sweep context size for PP/TG check (default: 2048).
|
||||
--sweep-n N Sweep generation tokens (default: 32).
|
||||
-h, --help Show this help.
|
||||
|
||||
What this script runs (in this order):
|
||||
1) CPU perplexity parity (chunks=1) mainline -> ik
|
||||
2) CPU perplexity parity (chunks=2) mainline -> ik
|
||||
3) CPU short generation smoke quality mainline -> ik
|
||||
4) Optional GPU sanity checks mainline -> ik
|
||||
|
||||
Output:
|
||||
A timestamped folder is created under OUT_ROOT with:
|
||||
- SUMMARY.md
|
||||
- run.log
|
||||
- *.out / *.err logs for each command
|
||||
USAGE
|
||||
}
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--with-gpu)
|
||||
WITH_GPU=1
|
||||
shift
|
||||
;;
|
||||
--gpu-device)
|
||||
GPU_DEVICE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--image)
|
||||
IMAGE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--main-repo)
|
||||
MAIN_REPO="$2"
|
||||
shift 2
|
||||
;;
|
||||
--ik-repo)
|
||||
IK_REPO="$2"
|
||||
shift 2
|
||||
;;
|
||||
--model)
|
||||
MODEL_HOST="$2"
|
||||
shift 2
|
||||
;;
|
||||
--out-root)
|
||||
OUT_ROOT="$2"
|
||||
shift 2
|
||||
;;
|
||||
--sweep-ctx)
|
||||
SWEEP_CTX="$2"
|
||||
shift 2
|
||||
;;
|
||||
--sweep-n)
|
||||
SWEEP_N="$2"
|
||||
shift 2
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
exit 2
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ ! -d "$MAIN_REPO" ]]; then
|
||||
echo "Mainline repo does not exist: $MAIN_REPO" >&2
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -d "$IK_REPO" ]]; then
|
||||
echo "ik repo does not exist: $IK_REPO" >&2
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -f "$MODEL_HOST" ]]; then
|
||||
echo "Model file does not exist: $MODEL_HOST" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
run_id="$(date +%Y%m%d_%H%M%S)"
|
||||
out_dir="${OUT_ROOT%/}/${run_id}"
|
||||
mkdir -p "$out_dir"
|
||||
|
||||
cat > "${out_dir}/ppl_input.txt" <<'TXT'
|
||||
Deterministic evaluation text for quick perplexity parity checks.
|
||||
The next lines intentionally repeat a simple pattern to reduce variance.
|
||||
TXT
|
||||
for _ in $(seq 1 400); do
|
||||
echo "the system writes logs and the system reads logs" >> "${out_dir}/ppl_input.txt"
|
||||
done
|
||||
|
||||
cat > "${out_dir}/gen_prompt.txt" <<'TXT'
|
||||
Write a concise Python function that returns the first n Fibonacci numbers iteratively, and then give one sentence explaining time complexity.
|
||||
TXT
|
||||
|
||||
cat > "${out_dir}/run_inside.sh" <<'BASH'
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
WITH_GPU="${WITH_GPU:-0}"
|
||||
GPU_DEVICE="${GPU_DEVICE:-0}"
|
||||
SWEEP_CTX="${SWEEP_CTX:-2048}"
|
||||
SWEEP_N="${SWEEP_N:-32}"
|
||||
|
||||
MAIN_LD="/mainline/build/bin"
|
||||
IK_LD="/ik/build/src:/ik/build/ggml/src:/ik/build/examples/mtmd"
|
||||
MODEL="/model.gguf"
|
||||
|
||||
RUN_LOG="/out/run.log"
|
||||
STATUS_FILE="/out/status.tsv"
|
||||
|
||||
touch "$RUN_LOG"
|
||||
printf "name\tstatus\texit_code\thost_mem_used_before_mib\thost_mem_used_after_mib\tgpu_mem_used_before_mib\tgpu_mem_used_after_mib\tmax_rss_kib\telapsed\n" > "$STATUS_FILE"
|
||||
|
||||
log() {
|
||||
local msg="$1"
|
||||
printf "[%s] %s\n" "$(date +%H:%M:%S)" "$msg" | tee -a "$RUN_LOG"
|
||||
}
|
||||
|
||||
require_bin() {
|
||||
local path="$1"
|
||||
if [[ ! -x "$path" ]]; then
|
||||
log "MISSING: $path"
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
host_mem_used_mib() {
|
||||
awk '
|
||||
/MemTotal:/ { mt = $2 }
|
||||
/MemAvailable:/ { ma = $2 }
|
||||
END {
|
||||
if (mt > 0 && ma >= 0) {
|
||||
printf "%.1f", (mt - ma) / 1024.0
|
||||
} else {
|
||||
print "NA"
|
||||
}
|
||||
}
|
||||
' /proc/meminfo
|
||||
}
|
||||
|
||||
gpu_mem_used_mib() {
|
||||
if [[ "$WITH_GPU" != "1" ]]; then
|
||||
echo "NA"
|
||||
return
|
||||
fi
|
||||
if ! command -v nvidia-smi >/dev/null 2>&1; then
|
||||
echo "NA"
|
||||
return
|
||||
fi
|
||||
local used
|
||||
used="$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits 2>/dev/null | tr '\n' ',' | sed 's/,$//' || true)"
|
||||
if [[ -z "$used" ]]; then
|
||||
echo "NA"
|
||||
else
|
||||
echo "$used"
|
||||
fi
|
||||
}
|
||||
|
||||
extract_max_rss_kib() {
|
||||
local time_file="$1"
|
||||
if [[ ! -f "$time_file" ]]; then
|
||||
echo "NA"
|
||||
return
|
||||
fi
|
||||
local rss
|
||||
rss="$(grep -E '^Maximum resident set size' "$time_file" | awk '{print $6}' | tail -n1 || true)"
|
||||
if [[ -z "$rss" ]]; then
|
||||
echo "NA"
|
||||
else
|
||||
echo "$rss"
|
||||
fi
|
||||
}
|
||||
|
||||
extract_elapsed() {
|
||||
local time_file="$1"
|
||||
if [[ ! -f "$time_file" ]]; then
|
||||
echo "NA"
|
||||
return
|
||||
fi
|
||||
local elapsed
|
||||
elapsed="$(grep -E '^Elapsed \(wall clock\) time' "$time_file" | sed -E 's/^[^:]+:[[:space:]]*//' | tail -n1 || true)"
|
||||
if [[ -z "$elapsed" ]]; then
|
||||
echo "NA"
|
||||
else
|
||||
echo "$elapsed"
|
||||
fi
|
||||
}
|
||||
|
||||
run_cmd() {
|
||||
local name="$1"
|
||||
shift
|
||||
local out_file="/out/${name}.out"
|
||||
local err_file="/out/${name}.err"
|
||||
local time_file="/out/${name}.time"
|
||||
local ec
|
||||
local host_before host_after gpu_before gpu_after max_rss elapsed
|
||||
|
||||
host_before="$(host_mem_used_mib)"
|
||||
gpu_before="$(gpu_mem_used_mib)"
|
||||
log "RUN: $name"
|
||||
|
||||
set +e
|
||||
if [[ -x /usr/bin/time ]]; then
|
||||
/usr/bin/time -v -o "$time_file" "$@" >"$out_file" 2>"$err_file"
|
||||
ec=$?
|
||||
else
|
||||
"$@" >"$out_file" 2>"$err_file"
|
||||
ec=$?
|
||||
fi
|
||||
set -e
|
||||
|
||||
host_after="$(host_mem_used_mib)"
|
||||
gpu_after="$(gpu_mem_used_mib)"
|
||||
max_rss="$(extract_max_rss_kib "$time_file")"
|
||||
elapsed="$(extract_elapsed "$time_file")"
|
||||
|
||||
if [[ $ec -eq 0 ]]; then
|
||||
printf "%s\tOK\t0\t%s\t%s\t%s\t%s\t%s\t%s\n" \
|
||||
"$name" "$host_before" "$host_after" "$gpu_before" "$gpu_after" "$max_rss" "$elapsed" >> "$STATUS_FILE"
|
||||
log "OK: $name"
|
||||
else
|
||||
printf "%s\tFAIL\t%d\t%s\t%s\t%s\t%s\t%s\t%s\n" \
|
||||
"$name" "$ec" "$host_before" "$host_after" "$gpu_before" "$gpu_after" "$max_rss" "$elapsed" >> "$STATUS_FILE"
|
||||
log "FAIL($ec): $name"
|
||||
fi
|
||||
return $ec
|
||||
}
|
||||
|
||||
extract_ppl() {
|
||||
local out_file="$1"
|
||||
local err_file="$2"
|
||||
local line num
|
||||
|
||||
line="$(cat "$out_file" "$err_file" 2>/dev/null | grep -E "Final estimate:" | tail -n1 || true)"
|
||||
if [[ -z "$line" ]]; then
|
||||
echo "NA"
|
||||
return
|
||||
fi
|
||||
|
||||
num="$(echo "$line" | sed -nE 's/.*= ([0-9]+\.[0-9]+).*/\1/p')"
|
||||
if [[ -z "$num" ]]; then
|
||||
num="$(echo "$line" | grep -Eo '[0-9]+\.[0-9]+' | head -n1 || true)"
|
||||
fi
|
||||
if [[ -z "$num" ]]; then
|
||||
echo "NA"
|
||||
else
|
||||
echo "$num"
|
||||
fi
|
||||
}
|
||||
|
||||
abs_delta() {
|
||||
local a="$1"
|
||||
local b="$2"
|
||||
awk -v a="$a" -v b="$b" 'BEGIN { d = a - b; if (d < 0) d = -d; printf "%.6f", d }'
|
||||
}
|
||||
|
||||
has_token() {
|
||||
local file="$1"
|
||||
local pattern="$2"
|
||||
if grep -Eiq "$pattern" "$file"; then
|
||||
echo "yes"
|
||||
else
|
||||
echo "no"
|
||||
fi
|
||||
}
|
||||
|
||||
main_ppl() {
|
||||
LD_LIBRARY_PATH="$MAIN_LD" /mainline/build/bin/llama-perplexity "$@"
|
||||
}
|
||||
|
||||
ik_ppl() {
|
||||
LD_LIBRARY_PATH="$IK_LD" /ik/build/bin/llama-perplexity "$@"
|
||||
}
|
||||
|
||||
main_cli() {
|
||||
LD_LIBRARY_PATH="$MAIN_LD" /mainline/build/bin/llama-cli "$@"
|
||||
}
|
||||
|
||||
main_completion() {
|
||||
LD_LIBRARY_PATH="$MAIN_LD" /mainline/build/bin/llama-completion "$@"
|
||||
}
|
||||
|
||||
ik_cli() {
|
||||
LD_LIBRARY_PATH="$IK_LD" /ik/build/bin/llama-cli "$@"
|
||||
}
|
||||
|
||||
main_sweep() {
|
||||
LD_LIBRARY_PATH="$MAIN_LD" /mainline/build/bin/llama-sweep-bench "$@"
|
||||
}
|
||||
|
||||
ik_sweep() {
|
||||
LD_LIBRARY_PATH="$IK_LD" /ik/build/bin/llama-sweep-bench "$@"
|
||||
}
|
||||
|
||||
require_bin "/mainline/build/bin/llama-perplexity"
|
||||
require_bin "/mainline/build/bin/llama-cli"
|
||||
require_bin "/mainline/build/bin/llama-completion"
|
||||
require_bin "/ik/build/bin/llama-perplexity"
|
||||
require_bin "/ik/build/bin/llama-cli"
|
||||
|
||||
if [[ "$WITH_GPU" != "1" ]]; then
|
||||
export CUDA_VISIBLE_DEVICES=""
|
||||
log "GPU checks disabled (CPU-only mode)"
|
||||
else
|
||||
export CUDA_VISIBLE_DEVICES="$GPU_DEVICE"
|
||||
log "GPU checks enabled on CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
fi
|
||||
|
||||
PPL_INPUT="/out/ppl_input.txt"
|
||||
GEN_PROMPT="$(cat /out/gen_prompt.txt)"
|
||||
|
||||
# CPU perplexity: chunks=1 (mainline -> ik)
|
||||
run_cmd "cpu_ppl_chunks1_mainline" \
|
||||
main_ppl -m "$MODEL" -f "$PPL_INPUT" -c 256 -b 64 -ub 64 --chunks 1 --no-warmup -ngl 0 || true
|
||||
run_cmd "cpu_ppl_chunks1_ik" \
|
||||
ik_ppl -m "$MODEL" -f "$PPL_INPUT" -c 256 -b 64 -ub 64 --chunks 1 --no-warmup -ngl 0 || true
|
||||
|
||||
# CPU perplexity: chunks=2 (mainline -> ik)
|
||||
run_cmd "cpu_ppl_chunks2_mainline" \
|
||||
main_ppl -m "$MODEL" -f "$PPL_INPUT" -c 256 -b 64 -ub 64 --chunks 2 --no-warmup -ngl 0 || true
|
||||
run_cmd "cpu_ppl_chunks2_ik" \
|
||||
ik_ppl -m "$MODEL" -f "$PPL_INPUT" -c 256 -b 64 -ub 64 --chunks 2 --no-warmup -ngl 0 || true
|
||||
|
||||
# CPU short generation smoke quality (mainline -> ik)
|
||||
run_cmd "cpu_gen_mainline" \
|
||||
main_completion -m "$MODEL" --cpu-moe -ngl 0 -c 512 -n 64 --seed 123 --temp 0 --top-k 1 --simple-io --no-display-prompt -p "$GEN_PROMPT" || true
|
||||
run_cmd "cpu_gen_ik" \
|
||||
ik_cli -m "$MODEL" --cpu-moe -ngl 0 -c 512 -n 64 --seed 123 --temp 0 --top-k 1 --simple-io --no-display-prompt -p "$GEN_PROMPT" || true
|
||||
|
||||
if [[ "$WITH_GPU" == "1" ]]; then
|
||||
# CUDA sanity perplexity: chunks=1 (mainline -> ik)
|
||||
run_cmd "gpu_ppl_chunks1_mainline" \
|
||||
main_ppl -m "$MODEL" -f "$PPL_INPUT" -c 256 -b 64 -ub 64 --chunks 1 --no-warmup -ngl 1 || true
|
||||
run_cmd "gpu_ppl_chunks1_ik" \
|
||||
ik_ppl -m "$MODEL" -f "$PPL_INPUT" -c 256 -b 64 -ub 64 --chunks 1 --no-warmup -ngl 1 || true
|
||||
|
||||
# Quick sweep sanity (mainline -> ik)
|
||||
if [[ -x /mainline/build/bin/llama-sweep-bench ]]; then
|
||||
run_cmd "gpu_sweep_mainline" \
|
||||
main_sweep -m "$MODEL" --cpu-moe -ngl 999 -c "$SWEEP_CTX" -b 1024 -ub 128 -n "$SWEEP_N" -ctk f16 -ctv f16 || true
|
||||
else
|
||||
printf "%s\tSKIP\t0\tNA\tNA\tNA\tNA\tNA\tNA\n" "gpu_sweep_mainline" >> "$STATUS_FILE"
|
||||
log "SKIP: gpu_sweep_mainline (missing /mainline/build/bin/llama-sweep-bench)"
|
||||
fi
|
||||
if [[ -x /ik/build/bin/llama-sweep-bench ]]; then
|
||||
run_cmd "gpu_sweep_ik" \
|
||||
ik_sweep -m "$MODEL" --cpu-moe -ngl 999 -c "$SWEEP_CTX" -b 1024 -ub 128 -n "$SWEEP_N" -ctk f16 -ctv f16 || true
|
||||
else
|
||||
printf "%s\tSKIP\t0\tNA\tNA\tNA\tNA\tNA\tNA\n" "gpu_sweep_ik" >> "$STATUS_FILE"
|
||||
log "SKIP: gpu_sweep_ik (missing /ik/build/bin/llama-sweep-bench)"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Aggregate summary
|
||||
cpu_c1_main="$(extract_ppl /out/cpu_ppl_chunks1_mainline.out /out/cpu_ppl_chunks1_mainline.err)"
|
||||
cpu_c1_ik="$(extract_ppl /out/cpu_ppl_chunks1_ik.out /out/cpu_ppl_chunks1_ik.err)"
|
||||
cpu_c2_main="$(extract_ppl /out/cpu_ppl_chunks2_mainline.out /out/cpu_ppl_chunks2_mainline.err)"
|
||||
cpu_c2_ik="$(extract_ppl /out/cpu_ppl_chunks2_ik.out /out/cpu_ppl_chunks2_ik.err)"
|
||||
|
||||
cpu_c1_delta="NA"
|
||||
cpu_c2_delta="NA"
|
||||
if [[ "$cpu_c1_main" != "NA" && "$cpu_c1_ik" != "NA" ]]; then
|
||||
cpu_c1_delta="$(abs_delta "$cpu_c1_main" "$cpu_c1_ik")"
|
||||
fi
|
||||
if [[ "$cpu_c2_main" != "NA" && "$cpu_c2_ik" != "NA" ]]; then
|
||||
cpu_c2_delta="$(abs_delta "$cpu_c2_main" "$cpu_c2_ik")"
|
||||
fi
|
||||
|
||||
main_has_fib="$(has_token /out/cpu_gen_mainline.out 'fibonacci|fibs|fib')"
|
||||
ik_has_fib="$(has_token /out/cpu_gen_ik.out 'fibonacci|fibs|fib')"
|
||||
main_has_complexity="$(has_token /out/cpu_gen_mainline.out 'complexity|O\(')"
|
||||
ik_has_complexity="$(has_token /out/cpu_gen_ik.out 'complexity|O\(')"
|
||||
|
||||
{
|
||||
echo "# Qwen3Next Eval Summary"
|
||||
echo
|
||||
echo "Mode: $( [[ "$WITH_GPU" == "1" ]] && echo "CPU+GPU" || echo "CPU-only" )"
|
||||
echo "- Sweep config: c=\`$SWEEP_CTX\`, n=\`$SWEEP_N\`"
|
||||
echo
|
||||
echo "## CPU Perplexity"
|
||||
echo "- chunks=1 mainline: \`$cpu_c1_main\`"
|
||||
echo "- chunks=1 ik: \`$cpu_c1_ik\`"
|
||||
echo "- chunks=1 |delta|: \`$cpu_c1_delta\`"
|
||||
echo "- chunks=2 mainline: \`$cpu_c2_main\`"
|
||||
echo "- chunks=2 ik: \`$cpu_c2_ik\`"
|
||||
echo "- chunks=2 |delta|: \`$cpu_c2_delta\`"
|
||||
echo
|
||||
echo "## CPU Short Generation Smoke"
|
||||
echo "- mainline has Fibonacci token(s): \`$main_has_fib\`"
|
||||
echo "- ik has Fibonacci token(s): \`$ik_has_fib\`"
|
||||
echo "- mainline has complexity token(s): \`$main_has_complexity\`"
|
||||
echo "- ik has complexity token(s): \`$ik_has_complexity\`"
|
||||
echo
|
||||
echo "## Command Status + Memory"
|
||||
echo '```'
|
||||
cat "$STATUS_FILE"
|
||||
echo '```'
|
||||
echo
|
||||
echo "## First Non-empty Lines (Generation)"
|
||||
echo "### mainline"
|
||||
awk 'NF { print; c++; if (c == 20) exit }' /out/cpu_gen_mainline.out
|
||||
echo
|
||||
echo "### ik"
|
||||
awk 'NF { print; c++; if (c == 20) exit }' /out/cpu_gen_ik.out
|
||||
} > /out/SUMMARY.md
|
||||
|
||||
log "Summary written to /out/SUMMARY.md"
|
||||
BASH
|
||||
|
||||
chmod +x "${out_dir}/run_inside.sh"
|
||||
|
||||
docker_cmd=(
|
||||
docker run --rm
|
||||
-e WITH_GPU="${WITH_GPU}"
|
||||
-e GPU_DEVICE="${GPU_DEVICE}"
|
||||
-e SWEEP_CTX="${SWEEP_CTX}"
|
||||
-e SWEEP_N="${SWEEP_N}"
|
||||
-v "${MAIN_REPO}:/mainline"
|
||||
-v "${IK_REPO}:/ik"
|
||||
-v "${MODEL_HOST}:/model.gguf:ro"
|
||||
-v "${out_dir}:/out"
|
||||
)
|
||||
|
||||
if [[ "$WITH_GPU" -eq 1 ]]; then
|
||||
docker_cmd+=(--gpus all)
|
||||
fi
|
||||
|
||||
docker_cmd+=("${IMAGE}" /bin/bash /out/run_inside.sh)
|
||||
|
||||
echo "Running eval in container: ${IMAGE}"
|
||||
echo "Output directory: ${out_dir}"
|
||||
"${docker_cmd[@]}"
|
||||
|
||||
echo
|
||||
echo "Done. Summary:"
|
||||
echo " ${out_dir}/SUMMARY.md"
|
||||
echo "Raw logs:"
|
||||
echo " ${out_dir}/*.out"
|
||||
echo " ${out_dir}/*.err"
|
||||
@@ -1861,6 +1861,19 @@ static bool llm_load_tensors(
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
splits[i] = llama_get_device_memory(model, model.devices[i]);
|
||||
}
|
||||
|
||||
// For Qwen3Next on heterogeneous dual-GPU setups, pure free-memory split tends to
|
||||
// over-assign layers to the secondary GPU and hurts decode throughput.
|
||||
if (model.arch == LLM_ARCH_QWEN3NEXT && split_mode == LLAMA_SPLIT_MODE_LAYER && device_count == 2 && splits[0] >= splits[1]) {
|
||||
const float split_sum = splits[0] + splits[1];
|
||||
if (split_sum > 0.0f) {
|
||||
const float primary_share = splits[0] / split_sum;
|
||||
if (primary_share < 0.75f) {
|
||||
splits[0] = 0.75f;
|
||||
splits[1] = 0.25f;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::copy(tensor_split, tensor_split + device_count, splits.begin());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user