mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
CUDA: set compute parameters via command line arguments (#910)
* cuda: set compute parameters via command line arguments
* Also llama-bench
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -21,7 +21,7 @@ extern "C" {
|
||||
#define GGML_CUDA_MAX_DEVICES 16
|
||||
|
||||
// backend API
|
||||
GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
|
||||
GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, const void * params);
|
||||
|
||||
GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#define IK_PRINT_TIMING 0
|
||||
|
||||
@@ -2420,7 +2421,8 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_is_quantized(src0->type) && ggml_cuda_can_use_mmq_id(src0->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
|
||||
if (src1->ne[2] <= ctx.mmq_id_thresh*src0->ne[2] &&
|
||||
ggml_is_quantized(src0->type) && ggml_cuda_can_use_mmq_id(src0->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
|
||||
ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr);
|
||||
return false;
|
||||
}
|
||||
@@ -2685,7 +2687,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
// My original hypothesis was that it is dependent on the total/active experts ratio, but from these 3 it
|
||||
// looks like it really depends just on the total number of experts.
|
||||
// TODO: verify with more models, or perhaps make the magic constant '32' to be defined via a compile time define.
|
||||
if (src1->ne[2] <= 32*src0->ne[2] &&
|
||||
if (src1->ne[2] <= ctx.mmq_id_thresh*src0->ne[2] &&
|
||||
ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1 &&
|
||||
ggml_cuda_can_use_mmq_id(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
|
||||
|
||||
@@ -3060,6 +3062,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
|
||||
auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr;
|
||||
|
||||
auto fusion = ctx.fusion;
|
||||
|
||||
//printf("%4d %s(%s)\n", i, ggml_op_name(dst->op), dst->name);
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ARGMAX:
|
||||
@@ -3084,7 +3088,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_dup(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
|
||||
if (fusion && i + 2 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_ADD &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
|
||||
ggml_is_contiguous(dst->src[0]) &&
|
||||
@@ -3098,7 +3102,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
||||
i += 2;
|
||||
}
|
||||
else if (GGML_CUDA_FUSION && i + 1 < cgraph->n_nodes &&
|
||||
else if (fusion && i + 1 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
|
||||
ggml_is_contiguous(dst->src[0]) &&
|
||||
ggml_is_contiguous(dst->src[1]) &&
|
||||
@@ -3155,7 +3159,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_relu(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
if (GGML_CUDA_FUSION && i + 5 < cgraph->n_nodes &&
|
||||
if (fusion && i + 5 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
|
||||
cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
|
||||
@@ -3164,14 +3168,14 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]);
|
||||
i += 5;
|
||||
}
|
||||
else if (GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
|
||||
else if (fusion && i + 4 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
|
||||
cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK &&
|
||||
cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+4)) {
|
||||
cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]);
|
||||
i += 4;
|
||||
} else if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
|
||||
} else if (fusion && i + 2 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_ADD && ops_are_same_device(cgraph, i, i+2)) {
|
||||
ggml_cuda_op_biased_sigmoid(ctx, cgraph->nodes[i+2]);
|
||||
@@ -3242,7 +3246,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_rms_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_FUSED_RMS_NORM:
|
||||
if (false && GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
|
||||
if (false && fusion && i + 4 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
|
||||
cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
|
||||
@@ -3250,7 +3254,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_fused_rms_rope_fast(ctx, cgraph->nodes[i+3], cgraph->nodes[i+4])) {
|
||||
i += 4;
|
||||
}
|
||||
else if (false && GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
|
||||
else if (false && fusion && i + 4 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_RESHAPE &&
|
||||
cgraph->nodes[i+3]->op == GGML_OP_FUSED_RMS_NORM &&
|
||||
@@ -3258,7 +3262,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_fused_rms_rope_fast(ctx, cgraph->nodes[i+1], cgraph->nodes[i+4])) {
|
||||
i += 4;
|
||||
}
|
||||
else if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
|
||||
else if (fusion && i + 2 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
|
||||
dst->ne[2] == 1 && cgraph->nodes[i+2]->ne[2] == 1) {
|
||||
@@ -3310,7 +3314,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_diag_mask_inf(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
if (GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
|
||||
if (fusion && i + 4 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
|
||||
cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
|
||||
@@ -3333,20 +3337,20 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_rope_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROPE_FAST:
|
||||
if (GGML_CUDA_FUSION && i + 3 < cgraph->n_nodes &&
|
||||
if (fusion && i + 3 < cgraph->n_nodes &&
|
||||
(cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
|
||||
(cgraph->nodes[i+2]->op == GGML_OP_RESHAPE || cgraph->nodes[i+2]->op == GGML_OP_VIEW) &&
|
||||
cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
|
||||
ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+3])) {
|
||||
i += 3;
|
||||
}
|
||||
else if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
|
||||
else if (fusion && i + 2 < cgraph->n_nodes &&
|
||||
(cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_ROPE_FAST &&
|
||||
ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+2])) {
|
||||
i += 2;
|
||||
}
|
||||
else if (GGML_CUDA_FUSION && i + 1 < cgraph->n_nodes &&
|
||||
else if (fusion && i + 1 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST &&
|
||||
ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+1])) {
|
||||
i += 1;
|
||||
@@ -3374,7 +3378,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_pool2d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
|
||||
if (fusion && i + 2 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_SCALE &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_DIV &&
|
||||
cgraph->nodes[i+1]->src[0] == dst &&
|
||||
@@ -3383,7 +3387,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+2]);
|
||||
i += 2;
|
||||
}
|
||||
else if (GGML_CUDA_FUSION && i + 1 < cgraph->n_nodes &&
|
||||
else if (fusion && i + 1 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_DIV &&
|
||||
cgraph->nodes[i+1]->src[1] == dst &&
|
||||
cgraph->nodes[i+1]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+1)) {
|
||||
@@ -3394,7 +3398,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
}
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
if (GGML_CUDA_FUSION && i + 5 < cgraph->n_nodes &&
|
||||
if (fusion && i + 5 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
|
||||
cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
|
||||
@@ -4462,7 +4466,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, gg
|
||||
}
|
||||
|
||||
GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
|
||||
constexpr int min_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
|
||||
auto ctx = (const ggml_backend_cuda_context *)backend->context;
|
||||
int min_batch_size = ctx->offload_batch_size; //originally: GGML_CUDA_MIN_BATCH_OFFLOAD;
|
||||
|
||||
// Why do we want to do this? The heuristics that the batch must have more than min_batch_size tokens to be worth it
|
||||
// offloading the required model weights comes from dense models. For MoE models, the average number of tokens
|
||||
@@ -4575,7 +4580,65 @@ static ggml_guid_t ggml_backend_cuda_guid() {
|
||||
return &guid;
|
||||
}
|
||||
|
||||
GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
|
||||
struct cuda_params {
|
||||
int fusion = GGML_CUDA_FUSION;
|
||||
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
|
||||
int mmq_id_thresh = 32;
|
||||
};
|
||||
|
||||
// Splits `str` at every occurrence of `delimiter` and returns the pieces in
// their original order.
//
// - The result always contains at least one element: a string without any
//   delimiter (including the empty string) yields a single-element vector.
// - Adjacent, leading or trailing delimiters produce empty pieces.
// - An empty `delimiter` yields `{str}` unchanged. Without this guard
//   std::string::find("") matches at the current position on every
//   iteration, so the search would never advance and the loop would not end.
static std::vector<std::string> string_split(const std::string& str, const std::string& delimiter) {
    std::vector<std::string> parts;

    if (delimiter.empty()) {
        parts.push_back(str);
        return parts;
    }

    size_t start = 0;
    for (size_t end = str.find(delimiter); end != std::string::npos; end = str.find(delimiter, start)) {
        parts.push_back(str.substr(start, end - start));
        // Resume searching right after the delimiter that was just consumed.
        start = end + delimiter.length();
    }

    // Whatever follows the last delimiter (possibly empty) is the final piece.
    parts.push_back(str.substr(start));

    return parts;
}
|
||||
|
||||
// Parses the whole of `val` as a value of type T using stream extraction.
//
// Returns true and stores the parsed value in `result` only when the
// extraction succeeds AND nothing but whitespace follows the value. On any
// failure `result` is left untouched and false is returned.
//
// Rejecting trailing characters matters for inputs such as "12abc" or "1.5"
// (for an integer T): plain `operator>>` would stop at the first character it
// cannot use and report success with a truncated value.
template <typename T> bool read_value(const std::string& val, T& result) {
    std::istringstream str(val);

    T tmp;
    if (!(str >> tmp)) {
        return false;
    }

    // Formatted extraction of a char skips whitespace first, so this succeeds
    // only if some non-whitespace character is left over after the value.
    char extra;
    if (str >> extra) {
        return false;
    }

    result = tmp;
    return true;
}
|
||||
|
||||
// Builds a cuda_params from a comma-separated list of `key=value` entries.
//
// Recognized keys and the member each one sets:
//   fusion             -> cuda_params::fusion
//   offload-batch-size -> cuda_params::offload_batch_size
//   mmq-id-size        -> cuda_params::mmq_id_thresh
//
// A null `params_string` returns the defaults. Empty entries (an empty string,
// or a leading/trailing/doubled comma) are skipped silently. Any other entry
// that is not exactly `key=value` with a known key and a parsable value is
// reported with a warning and ignored; the corresponding member keeps its
// default.
static cuda_params ggml_cuda_parse_params(const char * params_string) {
    cuda_params params{};
    if (!params_string) return params;

    // string_split always returns at least one element, so there is no
    // "no entries" case to handle separately here.
    for (const auto& value : string_split(std::string{params_string}, ",")) {
        if (value.empty()) {
            continue; // nothing was specified in this slot; not worth a warning
        }

        auto parsed = string_split(value, "=");
        bool is_good = false;
        if (parsed.size() == 2) {
            const std::string& key = parsed[0];
            const std::string& val = parsed[1];
            if (key == "fusion") {
                is_good = read_value(val, params.fusion);
            }
            else if (key == "offload-batch-size") {
                is_good = read_value(val, params.offload_batch_size);
            }
            else if (key == "mmq-id-size") {
                is_good = read_value(val, params.mmq_id_thresh);
            }
        }

        if (!is_good) {
            GGML_CUDA_LOG_WARN("%s: invalid parameter %s (%d) -> ignored\n", __func__, value.c_str(), (int)parsed.size());
        }
    }

    return params;
}
|
||||
|
||||
GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, [[maybe_unused]] const void * param_string) {
|
||||
if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
|
||||
GGML_CUDA_LOG_ERROR("%s: invalid device %d\n", __func__, device);
|
||||
return nullptr;
|
||||
@@ -4593,6 +4656,22 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
|
||||
/* .context = */ ctx
|
||||
};
|
||||
|
||||
if (param_string) {
|
||||
[[maybe_unused]] auto params = ggml_cuda_parse_params((const char *)param_string);
|
||||
if (params.fusion != ctx->fusion) {
|
||||
GGML_CUDA_LOG_INFO(" =========================== %s: setting fusion to %d\n", __func__, params.fusion);
|
||||
ctx->fusion = params.fusion;
|
||||
}
|
||||
if (params.offload_batch_size != ctx->offload_batch_size) {
|
||||
GGML_CUDA_LOG_INFO(" =========================== %s: setting offload_batch_size to %d\n", __func__, params.offload_batch_size);
|
||||
ctx->offload_batch_size = params.offload_batch_size;
|
||||
}
|
||||
if (params.mmq_id_thresh != ctx->mmq_id_thresh) {
|
||||
GGML_CUDA_LOG_INFO(" =========================== %s: setting mmq_id_thresh to %d\n", __func__, params.mmq_id_thresh);
|
||||
ctx->mmq_id_thresh = params.mmq_id_thresh;
|
||||
}
|
||||
}
|
||||
|
||||
return cuda_backend;
|
||||
}
|
||||
|
||||
@@ -4651,7 +4730,7 @@ GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
|
||||
|
||||
// backend registry
|
||||
GGML_CALL static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) {
|
||||
ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data);
|
||||
ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data, nullptr);
|
||||
return cuda_backend;
|
||||
|
||||
GGML_UNUSED(params);
|
||||
|
||||
@@ -837,6 +837,10 @@ struct ggml_backend_cuda_context {
|
||||
|
||||
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
||||
|
||||
int fusion = GGML_CUDA_FUSION;
|
||||
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
|
||||
int mmq_id_thresh = 32;
|
||||
|
||||
explicit ggml_backend_cuda_context(int device);
|
||||
|
||||
~ggml_backend_cuda_context();
|
||||
|
||||
Reference in New Issue
Block a user