WIP: add Q8_0 and BF16 as possible reduce types
Does not work - there is a bug somewhere
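For context, a minimal sketch of how the new reduce type would be selected through the public API once this lands (the type_reduce field and the -grt/--graph-reduce-type flag are the ones added in the diff below; the surrounding setup is a placeholder, not code from this commit):

    struct llama_context_params cparams = llama_context_default_params();
    cparams.type_reduce = GGML_TYPE_BF16;   // or GGML_TYPE_Q8_0 once the bug noted above is fixed
    // equivalent to passing -grt bf16 (or --graph-reduce-type bf16) on the command line
    struct llama_context * lctx = llama_new_context_with_model(model, cparams);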
@@ -1459,11 +1459,18 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-smf16" || arg == "--split-mode-f16") {
params.split_mode_f16 = true;
params.reduce_type = "f16";
//params.split_mode_f16 = true;
return true;
}
if (arg == "-smf32" || arg == "--split-mode-f32") {
params.split_mode_f16 = false;
params.reduce_type = "f32";
//params.split_mode_f16 = false;
return true;
}
if (arg == "-grt" || arg == "--graph-reduce-type") {
CHECK_ARG
params.reduce_type = argv[i];
return true;
}
if (arg == "--numa") {
@@ -2154,8 +2161,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv});
options.push_back({ "*", "-muge, --merge-up-gate-experts,","merge ffn_up/gate_exps (default: %d)", params.merge_up_gate_exps});
options.push_back({ "*", "-khad, --k-cache-hadamard,", "Use Hadamard transform for K-cache (default: %d)", params.k_cache_hadamard});
options.push_back({ "*", "-smf16, --split-mode-f16,", "Use f16 for data exchange between GPUs (default: %d)", params.split_mode_f16});
options.push_back({ "*", "-smf32, --split-mode-f32,", "Use f32 for data exchange between GPUs (default: %d)", !params.split_mode_f16});
options.push_back({ "*", "-smf16, --split-mode-f16,", "Use f16 for data exchange between GPUs (default: %d)", true});
options.push_back({ "*", "-smf32, --split-mode-f32,", "Use f32 for data exchange between GPUs (default: %d)", false});
options.push_back({ "*", "-grt, --graph-reduce-type", "Type for data exchange between GPUs (default: %s)", "f32"});
options.push_back({ "*", "-smgs, --split-mode-graph-scheduling,", "Force Split Mode Graph Scheduling (default: %d)", params.split_mode_graph_scheduling});
options.push_back({ "*", "-sas, ==scheduler_async,", "Async evaluation of compute graphs: %d)", params.scheduler_async});
options.push_back({ "*", "-vq, --validate-quants", "validate quantized data while loading the model (default: %d)", params.validate_quants});
@@ -3148,6 +3156,22 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
throw std::runtime_error("Invalid cache type: " + s);
}

static ggml_type ggml_type_from_str(const std::string & s) {
if (s == "f32") {
return GGML_TYPE_F32;
}
if (s == "f16") {
return GGML_TYPE_F16;
}
if (s == "bf16") {
return GGML_TYPE_BF16;
}
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}
throw std::runtime_error("Invalid graph reduce type: " + s);
}
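A quick, hypothetical illustration of the helper above: it accepts exactly the four reduce types wired up in this commit and throws for anything else, so a bad --graph-reduce-type value fails when the gpt_params are converted to llama_context_params rather than being silently ignored:

    ggml_type t = ggml_type_from_str("bf16");   // GGML_TYPE_BF16
    printf("%s\n", ggml_type_name(t));          // prints "bf16"
    // ggml_type_from_str("q4_0") would throw std::runtime_error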

struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();
int n_batch = params.n_batch;
@@ -3194,7 +3218,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.graph_reuse = params.graph_reuse;
cparams.k_cache_hadamard = params.k_cache_hadamard;
cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling;
cparams.split_mode_f16 = params.split_mode_f16;
//cparams.split_mode_f16 = params.split_mode_f16;
cparams.scheduler_async = params.scheduler_async;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
@@ -3203,6 +3227,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
cparams.type_reduce = ggml_type_from_str(params.reduce_type);

if (!params.offload_policy.empty()) cparams.offload_policy = (void *)&params.offload_policy;
if (!params.cuda_params.empty()) cparams.cuda_params = (void *)params.cuda_params.data();
@@ -4180,7 +4205,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "graph_reuse: %s # default: false\n", params.graph_reuse ? "true" : "false");
fprintf(stream, "k_cache_hadamard: %s # default: false\n", params.k_cache_hadamard ? "true" : "false");
fprintf(stream, "split_mode_graph_scheduling: %s # default: false\n", params.split_mode_graph_scheduling ? "true" : "false");
fprintf(stream, "split_mode_f16: %s # default: true\n", params.split_mode_f16 ? "true" : "false");
//fprintf(stream, "split_mode_f16: %s # default: true\n", params.split_mode_f16 ? "true" : "false");
fprintf(stream, "reduce_type: %s # default f16\n", params.reduce_type.c_str());
fprintf(stream, "scheduler_async: %s # default: false\n", params.scheduler_async ? "true" : "false");
fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

@@ -290,7 +290,7 @@ struct gpt_params {
bool merge_up_gate_exps= false; // if true, merge ffn_up_exps and ffn_gate_exps into a single, contiguous tensor
bool k_cache_hadamard = false; // if true, use Hadamard transform for the K-cache (only makes sense with quantized cache)
bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling
bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops
//bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops
bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph

std::string cache_type_k = "f16"; // KV cache data type for the K
@@ -298,6 +298,8 @@ struct gpt_params {
std::string cache_type_k_draft = ""; // KV cache data type for K for the draft model
std::string cache_type_v_draft = ""; // KV cache data type for V for the draft model

std::string reduce_type = "f16";

// multimodal models (see examples/mtmd)
model_paths mmproj;
bool mmproj_use_gpu = true; // use GPU for multimodal model

@@ -43,10 +43,20 @@ static __global__ void fused_norm_f32(const T * x, const float * c, float * dst,

float2 mean_var = make_float2(0.f, 0.f);

for (int col = tid; col < ncols; col += block_size) {
const float xi = (float)x[row*ncols + col];
mean_var.x += xi;
mean_var.y += xi * xi;
if constexpr (std::is_same_v<T, block_q8_0>) {
static_assert(block_size % QK8_0 == 0);
auto xr = x + (row*ncols)/QK8_0;
for (int col = tid; col < ncols; col += block_size) {
const float xi = (float)xr[col / QK8_0].d * xr[col / QK8_0].qs[col % QK8_0];
mean_var.x += xi;
mean_var.y += xi * xi;
}
} else {
for (int col = tid; col < ncols; col += block_size) {
const float xi = (float)x[row*ncols + col];
mean_var.x += xi;
mean_var.y += xi * xi;
}
}

// sum up partial sums
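The Q8_0 branch above addresses the row as packed blocks rather than as a flat scalar array. A minimal device-side sketch of that indexing, assuming the CUDA-side block_q8_0 from ggml-common.h (one fp16 scale d plus QK8_0 = 32 signed 8-bit quants per block); this helper is illustrative and not part of the commit:

    static __device__ __forceinline__ float dequant_q8_0_elem(const block_q8_0 * row, int col) {
        const block_q8_0 & b = row[col / QK8_0];   // block that holds this column
        return (float)b.d * b.qs[col % QK8_0];     // scale * quantized value
    }

This is exactly the (float)xr[col / QK8_0].d * xr[col / QK8_0].qs[col % QK8_0] expression the loops above evaluate element by element.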
@@ -67,8 +77,16 @@ static __global__ void fused_norm_f32(const T * x, const float * c, float * dst,
const float var = mean_var.y / ncols - mean * mean;
const float inv_std = rsqrtf(var + eps);

for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = (T)(((float)x[row*ncols + col] - mean) * inv_std * c[col]);
if constexpr (std::is_same_v<T, block_q8_0>) {
static_assert(block_size % QK8_0 == 0);
auto xr = x + (row*ncols)/QK8_0;
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = ((float)xr[col/QK8_0].d*xr[col/QK8_0].qs[col%QK8_0] - mean) * inv_std * c[col];
}
} else {
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = ((float)x[row*ncols + col] - mean) * inv_std * c[col];
}
}
}

@@ -219,9 +237,19 @@ static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, floa

float tmp = 0.0f; // partial sum for thread in warp

for (int col = tid; col < ncols; col += block_size) {
const float xi = (float)x[row*ncols + col];
tmp += xi * xi;
if constexpr (std::is_same_v<src_t, block_q8_0>) {
static_assert(block_size % QK8_0 == 0);
auto xr = x + (row*ncols)/QK8_0;
for (int col = tid; col < ncols; col += block_size) {
const float xi = (float)xr[col / QK8_0].d * xr[col / QK8_0].qs[col % QK8_0];
tmp += xi * xi;
}

} else {
for (int col = tid; col < ncols; col += block_size) {
const float xi = (float)x[row*ncols + col];
tmp += xi * xi;
}
}

// sum up partial sums
@@ -241,8 +269,15 @@ static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, floa
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);

for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * y[col] * (float)x[row*ncols + col];
if constexpr (std::is_same_v<src_t, block_q8_0>) {
auto xr = x + (row*ncols)/QK8_0;
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * y[col] * (float)xr[col / QK8_0].d * xr[col / QK8_0].qs[col % QK8_0];
}
} else {
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * y[col] * (float)x[row*ncols + col];
}
}
}

@@ -496,7 +531,8 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 ||
(ggml_is_contiguous(src0) && src0->type == GGML_TYPE_Q8_0));
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
@@ -511,8 +547,12 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
const int64_t nrows = ggml_nrows(src0);
if (src0->type == GGML_TYPE_F32) {
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
} else {
} else if (src0->type == GGML_TYPE_F16) {
fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
} else if (src0->type == GGML_TYPE_Q8_0) {
fused_rms_norm_f32_cuda((const block_q8_0 *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
} else {
fused_rms_norm_f32_cuda((const nv_bfloat16 *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
}
} else {
if (is_norm) {
@@ -525,6 +565,8 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
auto s03 = src0->nb[3] / ts0;
if (src0->type == GGML_TYPE_F32) {
fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
} else if (src0->type == GGML_TYPE_BF16) {
fused_rms_norm_f32_nc_cuda((const nv_bfloat16 *)src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
} else {
fused_rms_norm_f32_nc_cuda((const half *)src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
}

@@ -6,6 +6,7 @@
//

#include "reduce.cuh"
#include "ggml-common.h"

#include <chrono>

@@ -16,6 +17,23 @@ static __global__ void k_add(int nelem, const T * src, T * dst) {
dst[i] += src[i];
}

template <int block_size>
static __global__ void k_add(int nelem, const block_q8_0 * src, block_q8_0 * dst) {
int i = blockIdx.x*block_size + threadIdx.x;
if (i >= nelem) return;
int ib = i / QK8_0;
int iq = i % QK8_0;
float x = (float)src[ib].d * src[ib].qs[iq] + (float)dst[ib].d * dst[ib].qs[iq];
float ax = fabsf(x);
float max = warp_reduce_max(ax);
float d = max / 127;
float id = d > 0 ? 1/d : 0;
dst[ib].qs[iq] = roundf(x * id);
if (threadIdx.x % WARP_SIZE == 0) {
dst[ib].d = (half)d;
}
}
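For reference, a host-side sketch of the re-quantization the Q8_0 k_add above performs per 32-element block; it relies on WARP_SIZE == QK8_0 so that one warp covers exactly one block and warp_reduce_max yields the block-wise absolute maximum. The helper below is only an illustration (inputs are assumed to be already dequantized to float), not code from the commit:

    // src_f, acc_f: the two operands of one block, dequantized to float
    static void add_requant_q8_0_block(const float * src_f, const float * acc_f, block_q8_0 * out) {
        float sum[QK8_0], amax = 0.0f;
        for (int j = 0; j < QK8_0; ++j) {
            sum[j] = src_f[j] + acc_f[j];
            amax   = fmaxf(amax, fabsf(sum[j]));
        }
        const float d  = amax / 127.0f;             // new per-block scale
        const float id = d > 0.0f ? 1.0f/d : 0.0f;
        for (int j = 0; j < QK8_0; ++j) out->qs[j] = (int8_t)roundf(sum[j] * id);
        out->d = (half)d;                           // stored as fp16, as in the kernel
    }

Note that each pairwise add re-quantizes the running sum, so a Q8_0 reduce is lossy and depends on the accumulation order.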

template <typename T, int block_size>
static __global__ void k_add_sym(int nelem, T * src, T * dst) {
int i = blockIdx.x*block_size + threadIdx.x;
@@ -68,7 +86,8 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
GGML_ASSERT(op == GGML_OP_ADD);
int nreduce = dst->op_params[1];
int nhave = dst->op_params[2];
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32 ||
dst->type == GGML_TYPE_Q8_0 || dst->type == GGML_TYPE_BF16);
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(nhave >=2 && nhave <= nreduce);
if (dst->op_params[3] == 1) {
@@ -82,10 +101,10 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
// It does not work at all if not all GPUs participate in the reduce op, and we
// get suboptimal prompt processing performance when we have more than 2 GPUs.
// Hence, if enabled, we use NCCL only for the cases where it works and performs well.
if (info.have_nccl && nhave == nreduce && (nhave == 2 || dst->ne[1] < 32)) {
if (false && info.have_nccl && dst->type != GGML_TYPE_Q8_0 && nhave == nreduce && (nhave == 2 || dst->ne[1] < 32)) {
GGML_ASSERT(info.have_nccl);
GGML_ASSERT(info.device_count == nreduce);
auto data_type = dst->type == GGML_TYPE_F32 ? ncclFloat : ncclHalf;
auto data_type = dst->type == GGML_TYPE_F32 ? ncclFloat : dst->type == GGML_TYPE_BF16 ? ncclBfloat16 : ncclHalf;
ncclGroupStart();
for (int i = 0; i < nreduce; ++i) {
ggml_cuda_set_device(i);
@@ -198,13 +217,25 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
//
if (dst->ne[1] >= 32) {
auto nelem = ggml_nelements(dst);
auto elem_size = ggml_element_size(dst);
auto nelem_per_device = (nelem + nhave - 1)/nhave;
auto required_size = nelem_per_device*elem_size;
auto tt = ggml_internal_get_type_traits(dst->type);
GGML_ASSERT(nelem % tt.blck_size == 0);
auto nblocks = nelem / tt.blck_size;
auto nblocks_per_device = (nblocks + nhave - 1)/nhave;
auto nelem_per_device = nblocks_per_device * tt.blck_size;
auto size_per_device = nblocks_per_device * tt.type_size;
//size_t nelem_per_device, required_size;
//if (dst->type == GGML_TYPE_Q8_0) {
// GGML_ASSERT(nelem % QK8_0 == 0);
// nelem_per_device = QK8_0*((nelem/QK8_0 + nhave - 1)/nhave);
// required_size nelem_per_device/QK8_0 * sizeof(ggml_block_q8_0);
//}
//auto elem_size = ggml_element_size(dst);
//auto nelem_per_device = (nelem + nhave - 1)/nhave;
//auto required_size = nelem_per_device*elem_size;
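// Worked example of the block-aligned chunking above (hypothetical numbers):
// for nelem = 14336, nhave = 3 and Q8_0 (blck_size = 32, type_size = 34 bytes):
//   nblocks            = 14336 / 32        = 448
//   nblocks_per_device = (448 + 3 - 1) / 3 = 150
//   nelem_per_device   = 150 * 32          = 4800 elements
//   size_per_device    = 150 * 34          = 5100 bytes
// i.e. chunk boundaries always fall on whole blocks, which the chunked peer
// copies below (reduce one chunk per device, then exchange the reduced chunks)
// need when the reduce type is quantized.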
for (int ii = 0; ii < nhave; ++ii) {
int i = idx[ii];
auto this_ctx = info.all_ctx[i];
if (!this_ctx->copy_event || !this_ctx->compute_event || required_size > this_ctx->copy_size) {
if (!this_ctx->copy_event || !this_ctx->compute_event || size_per_device > this_ctx->copy_size) {
ggml_cuda_set_device(this_ctx->device);
if (!this_ctx->copy_event) {
CUDA_CHECK(cudaEventCreateWithFlags(&this_ctx->copy_event, cudaEventDisableTiming));
@@ -212,12 +243,12 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
if (!this_ctx->compute_event) {
CUDA_CHECK(cudaEventCreateWithFlags(&this_ctx->compute_event, cudaEventDisableTiming));
}
if (required_size > this_ctx->copy_size) {
if (size_per_device > this_ctx->copy_size) {
if (this_ctx->copy_buffer) {
CUDA_CHECK(cudaFree(this_ctx->copy_buffer));
}
CUDA_CHECK(ggml_cuda_device_malloc(&this_ctx->copy_buffer, required_size, this_ctx->device));
this_ctx->copy_size = required_size;
CUDA_CHECK(ggml_cuda_device_malloc(&this_ctx->copy_buffer, size_per_device, this_ctx->device));
this_ctx->copy_size = size_per_device;
}
}
}
@@ -227,13 +258,14 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
int i = idx[ii];
int peer = idx[(ii+1)%nhave];
auto this_nelem = std::min(nelem_per_device, nelem - ichunk*nelem_per_device);
auto this_size = (this_nelem / tt.blck_size) * tt.type_size;
ggml_cuda_set_device(info.all_ctx[peer]->device);
if (stage > 0) {
CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[peer]->stream(), info.all_ctx[i]->compute_event, 0));
}
CUDA_CHECK(cudaMemcpyPeerAsync(info.all_ctx[i]->copy_buffer, info.all_ctx[i]->device,
(const char *)dst->src[peer]->data + ichunk*nelem_per_device*elem_size, info.all_ctx[peer]->device,
this_nelem*elem_size, info.all_ctx[peer]->stream()));
(const char *)dst->src[peer]->data + ichunk*size_per_device, info.all_ctx[peer]->device,
this_size, info.all_ctx[peer]->stream()));
CUDA_CHECK(cudaEventRecord(info.all_ctx[peer]->copy_event, info.all_ctx[peer]->stream()));
ichunk = (ichunk + 1)%nhave;
}
@@ -248,6 +280,13 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
if (dst->type == GGML_TYPE_F16) {
k_add<half, CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(this_nelem,
(const half *)info.all_ctx[i]->copy_buffer, (half *)dst->src[i]->data + ichunk*nelem_per_device);
} else if (dst->type == GGML_TYPE_Q8_0) {
k_add<CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(this_nelem,
(const block_q8_0 *)info.all_ctx[i]->copy_buffer, (block_q8_0 *)dst->src[i]->data + ichunk*nelem_per_device/tt.blck_size);
} else if (dst->type == GGML_TYPE_BF16) {
k_add<nv_bfloat16, CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(
this_nelem, (const nv_bfloat16 *)info.all_ctx[i]->copy_buffer,
(nv_bfloat16 *)dst->src[i]->data + ichunk*nelem_per_device);
} else {
k_add<float, CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(this_nelem,
(const float *)info.all_ctx[i]->copy_buffer, (float *)dst->src[i]->data + ichunk*nelem_per_device);
@@ -262,13 +301,14 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
int i = idx[ii];
int peer = idx[(ii+1)%nhave];
auto this_nelem = std::min(nelem_per_device, nelem - ichunk*nelem_per_device);
auto this_size = (this_nelem / tt.blck_size) * tt.type_size;
ggml_cuda_set_device(info.all_ctx[peer]->device);
if (stage == 0) {
CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[peer]->stream(), info.all_ctx[i]->compute_event, 0));
}
CUDA_CHECK(cudaMemcpyPeerAsync((char *)dst->src[i]->data + ichunk*nelem_per_device*elem_size, info.all_ctx[i]->device,
(const char *)dst->src[peer]->data + ichunk*nelem_per_device*elem_size, info.all_ctx[peer]->device,
this_nelem*elem_size, info.all_ctx[peer]->stream()));
CUDA_CHECK(cudaMemcpyPeerAsync((char *)dst->src[i]->data + ichunk*size_per_device, info.all_ctx[i]->device,
(const char *)dst->src[peer]->data + ichunk*size_per_device, info.all_ctx[peer]->device,
this_size, info.all_ctx[peer]->stream()));
CUDA_CHECK(cudaEventRecord(info.all_ctx[peer]->copy_event, info.all_ctx[peer]->stream()));
//ggml_cuda_set_device(info.all_ctx[i]->device);
//CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[peer]->copy_event, 0));
@@ -351,6 +391,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
return;
}
if (dst->ne[1] < 32 && ctx.p2p_enabled) {
GGML_ASSERT(dst->type != GGML_TYPE_Q8_0);
for (int ii = 0; ii < nhave; ++ii) {
int i = idx[ii];
GGML_ASSERT(dst->src[i]->type == dst->type);
@@ -464,6 +505,12 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), info.all_ctx[i]->copy_event, 0));
if (dst->type == GGML_TYPE_F16) {
k_add<half, CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, ctx.stream()>>>(nelem, (const half *)ptr, (half *)dst->data);
} else if (dst->type == GGML_TYPE_BF16) {
k_add<nv_bfloat16, CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, ctx.stream()>>>(nelem,
(const nv_bfloat16*)ptr, (nv_bfloat16 *)dst->data);
} else if (dst->type == GGML_TYPE_Q8_0) {
k_add<CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, ctx.stream()>>>(nelem, (const block_q8_0 *)ptr,
(block_q8_0 *)dst->data);
} else {
k_add<float, CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, ctx.stream()>>>(nelem, (const float *)ptr, (float *)dst->data);
}

@@ -426,6 +426,7 @@ extern "C" {

enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
enum ggml_type type_reduce; // data type for reduce operations

// Keep the booleans together to avoid misalignment during copy-by-value.
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
@@ -445,7 +446,7 @@ extern "C" {
bool only_active_experts;
bool k_cache_hadamard; // if true, apply Hadamard transfrom to K-cache
bool split_mode_graph_scheduling; // if true, force split mode graph scheduling
bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs
//bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs
bool scheduler_async; // if true, with split mode "graph" graph evaluation will be done using multiple threads

// Abort callback

@@ -633,6 +633,27 @@ static ggml_tensor * get_input_tensor_sm_graph(ggml_tensor * input, int id) {
return cur;
}

static inline ggml_tensor * do_split_norm(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * the_norm, const llama_hparams & hparams,
const llm_build_cb & cb, int id, int il_cb, bool is_norm) {
if (the_norm && the_norm->extra) {
auto norm = (ggml_split_tensor_t *)the_norm->extra;
GGML_ASSERT(norm->splits[id]);
//if (cur->type != GGML_TYPE_F16 && cur->type != GGML_TYPE_F32) {
// cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
//}
if (is_norm) {
cur = ggml_fused_norm(ctx, cur, norm->splits[id], hparams.f_norm_eps);
} else {
cur = llm_build_context::llm_build_norm(ctx, cur, hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il_cb);
}
cb(cur, "inp_normed", il_cb);
}
if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
}
return cur;
}

ggml_tensor * llm_build_context::llm_build_ffn(
ggml_context * ctx,
llama_context & lctx,
@@ -673,19 +694,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
GGML_ASSERT((!split_u && !split_g && !split_d) || (split_u && split_g && split_d));
if (!split_u) continue;
auto cur = get_input_tensor_sm_graph(input, id);
if (ffn_norm && ffn_norm->extra) {
auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
GGML_ASSERT(norm->splits[id]);
if (is_norm) {
cur = ggml_fused_norm(ctx, cur, norm->splits[id], lctx.model.hparams.f_norm_eps);
} else {
cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
}
cb(cur, "ffn_inp_normed", il_cb);
}
else if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
}
cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, is_norm);
cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
cb(cur, "ffn_up_gate", il_cb);
cur = llm_build_lora_mm(lctx, ctx, split_d, cur);
@@ -694,8 +703,8 @@ ggml_tensor * llm_build_context::llm_build_ffn(
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
if (cur->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, cur, lctx.cparams.reduce_type);
}
if (add_extra && add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) {
// When the reduce op is turned off via op_params[3] == 1, we need to add each src
@@ -1205,8 +1214,8 @@ llm_expert_gating_func_type gating_op,
split_down_shexp->splits[id], split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr,
nullptr, type_op_shexp, LLM_FFN_PAR, cb, il);
cb(shared_out, "ffn_shexp_out", il_cb);
if (shared_out->ne[1] > 32 && lctx.cparams.split_mode_f16) {
shared_out = ggml_cast(ctx, shared_out, GGML_TYPE_F16);
if (shared_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
shared_out = ggml_cast(ctx, shared_out, lctx.cparams.reduce_type);
}
results.push_back(shared_out);
}
@@ -1222,8 +1231,8 @@ llm_expert_gating_func_type gating_op,
cb(cur, "ffn_shared_combined", il);
}
}
if (routed_out->ne[1] > 32 && lctx.cparams.split_mode_f16) {
auto routed_out_f16 = ggml_cast(ctx, routed_out, GGML_TYPE_F16);
if (routed_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
auto routed_out_f16 = ggml_cast(ctx, routed_out, lctx.cparams.reduce_type);
cur = ggml_add(ctx, routed_out_f16, cur);
} else {
cur = ggml_add(ctx, routed_out, cur);
@@ -1269,15 +1278,16 @@ llm_expert_gating_func_type gating_op,
if (!split_up_exps->splits[id]) continue;
int il_cb = 1000*(id + 1) + il;
auto cur = get_input_tensor_sm_graph(input, id);
if (ffn_norm) {
auto split_ffn_norm = (ggml_split_tensor_t *)ffn_norm->extra;
GGML_ASSERT(split_ffn_norm && split_ffn_norm->n_device == split_up_exps->n_device);
cur = llm_build_norm(ctx, cur, lctx.model.hparams, split_ffn_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_inp_normed", il_cb);
}
if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
}
cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, false);
//if (ffn_norm) {
// auto split_ffn_norm = (ggml_split_tensor_t *)ffn_norm->extra;
// GGML_ASSERT(split_ffn_norm && split_ffn_norm->n_device == split_up_exps->n_device);
// cur = llm_build_norm(ctx, cur, lctx.model.hparams, split_ffn_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il);
// cb(cur, "ffn_inp_normed", il_cb);
//}
//if (cur->type != GGML_TYPE_F32) {
// cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
//}
GGML_ASSERT(!split_gate_inp_b || split_gate_inp_b->splits[id]);
GGML_ASSERT(!split_exps_down_b || split_exps_down_b->splits[id]);
GGML_ASSERT(!split_exps_gate_b || split_exps_gate_b->splits[id]);
@@ -1309,8 +1319,8 @@ llm_expert_gating_func_type gating_op,
} else {
cur = routed_out;
}
if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
if (cur->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, cur, lctx.cparams.reduce_type);
cb(cur, "ffn_out_f16", il_cb);
}
ggml_build_forward_expand(graph, cur);
@@ -9221,16 +9231,17 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
(split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
if (!split_wq) continue;
auto cur = get_input_tensor_sm_graph(input, id);
if (attn_norm) {
if (is_norm) {
cur = ggml_fused_norm(ctx0, cur, attn_norm->splits[id], lctx.model.hparams.f_norm_eps);
} else {
cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
}
}
if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
}
cur = do_split_norm(ctx0, cur, the_attn_norm, lctx.model.hparams, cb, id, il_cb, is_norm);
//if (attn_norm) {
// if (is_norm) {
// cur = ggml_fused_norm(ctx0, cur, attn_norm->splits[id], lctx.model.hparams.f_norm_eps);
// } else {
// cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
// }
//}
//if (cur->type != GGML_TYPE_F32) {
// cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
//}
auto the_q_norm = model.layers[il].attn_q_norm ? model.layers[il].attn_q_norm->extra ?
((ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra)->splits[id] : model.layers[il].attn_q_norm : nullptr;
auto the_k_norm = model.layers[il].attn_k_norm ? model.layers[il].attn_k_norm->extra ?
@@ -9368,8 +9379,8 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
cb(cur, "kqv_wo_biased", il_cb);
output_bias_added = true;
}
if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
cur = ggml_cast(ctx0, cur, GGML_TYPE_F16);
if (cur->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
cur = ggml_cast(ctx0, cur, lctx.cparams.reduce_type);
}
ggml_build_forward_expand(gf, cur);
attn[id] = cur;

@@ -41,11 +41,12 @@ struct llama_cparams {
bool graph_reuse;
bool k_cache_hadamard;
bool split_mode_graph_scheduling;
bool split_mode_f16;
//bool split_mode_f16;
bool scheduler_async;
int min_experts;
float thresh_experts;

enum ggml_type reduce_type;
enum llama_pooling_type pooling_type;

ggml_backend_sched_eval_callback cb_eval;

@@ -4070,6 +4070,7 @@ struct llama_context_params llama_context_default_params() {
/*.cb_eval_user_data =*/ nullptr,
/*.type_k =*/ GGML_TYPE_F16,
/*.type_v =*/ GGML_TYPE_F16,
/*.type_reduce =*/ GGML_TYPE_F16,
/*.logits_all =*/ false,
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
@@ -4087,7 +4088,7 @@ struct llama_context_params llama_context_default_params() {
/*.only_active_experts =*/ false,
/*.k_cache_hadamard =*/ false,
/*.split_mode_graph_scheduling =*/ false,
/*.split_mode_f16 =*/ true,
// /*.split_mode_f16 =*/ true,
/*.scheduler_async =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
@@ -4382,6 +4383,8 @@ struct llama_context * llama_new_context_with_model(
struct llama_model * model,
struct llama_context_params params) {

printf("===================================== %s: %s\n", __func__, ggml_type_name(params.type_reduce));

if (!model) {
LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
return nullptr;
@@ -4452,12 +4455,13 @@ struct llama_context * llama_new_context_with_model(
cparams.graph_reuse = params.graph_reuse;
cparams.k_cache_hadamard = params.k_cache_hadamard;
cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling;
cparams.split_mode_f16 = params.split_mode_f16;
//cparams.split_mode_f16 = params.split_mode_f16;
cparams.scheduler_async = params.scheduler_async;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
cparams.cuda_params = params.cuda_params;

cparams.reduce_type = params.type_reduce;
cparams.pooling_type = params.pooling_type;

cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -4527,12 +4531,19 @@ struct llama_context * llama_new_context_with_model(
cparams.mla_attn = 0;
}
if (model->arch == LLM_ARCH_OPENAI_MOE && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
if (cparams.split_mode_f16) {
//if (cparams.split_mode_f16) {
// LLAMA_LOG_WARN("=====================================================================\n");
// LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
// LLAMA_LOG_WARN(" => changing cparams.split_mode_f16 to 'false'\n");
// LLAMA_LOG_WARN("=====================================================================\n");
// cparams.split_mode_f16 = false;
//}
if (cparams.reduce_type == GGML_TYPE_F16) {
LLAMA_LOG_WARN("=====================================================================\n");
LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
LLAMA_LOG_WARN(" => changing cparams.split_mode_f16 to 'false'\n");
LLAMA_LOG_WARN("=====================================================================\n");
cparams.split_mode_f16 = false;
cparams.reduce_type = GGML_TYPE_F32;
}
}

@@ -4552,7 +4563,8 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: graph_reuse = %d\n", __func__, cparams.graph_reuse);
LLAMA_LOG_INFO("%s: k_cache_hadam = %d\n", __func__, cparams.k_cache_hadamard);
LLAMA_LOG_INFO("%s: split_mode_graph_scheduling = %d\n", __func__, cparams.split_mode_graph_scheduling);
LLAMA_LOG_INFO("%s: split_mode_f16= %d\n", __func__, cparams.split_mode_f16);
//LLAMA_LOG_INFO("%s: split_mode_f16= %d\n", __func__, cparams.split_mode_f16);
LLAMA_LOG_INFO("%s: reduce_type = %s\n", __func__, ggml_type_name(cparams.reduce_type));
LLAMA_LOG_INFO("%s: sched_async = %d\n", __func__, cparams.scheduler_async);
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);