Fused Q and K fused_rms_norm for TG on CUDA (#882)
* Biased mmvq: minor optimization

* Fusing Q and K rms_norm for TG on CUDA

* Remove commented out code

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
@@ -3244,7 +3244,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rms_norm(ctx, dst);
             break;
         case GGML_OP_FUSED_RMS_NORM:
-            ggml_cuda_op_fused_rms_norm(ctx, dst);
+            if (i + 2 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
+                cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
+                dst->ne[2] == 1 && cgraph->nodes[i+2]->ne[2] == 1) {
+                ggml_cuda_op_fused_rms_rms_norm(ctx, dst, cgraph->nodes[i+2]);
+                i += 2;
+            } else {
+                ggml_cuda_op_fused_rms_norm(ctx, dst);
+            }
             break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
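The dispatch change above is where the fusion happens: when a GGML_OP_FUSED_RMS_NORM node is followed by a GGML_OP_VIEW and then a second GGML_OP_FUSED_RMS_NORM, and both norms cover a single token (ne[2] == 1, i.e. token generation rather than prompt processing), both norms are executed by one kernel launch and the loop skips the view and the second norm. A sketch of what the unfused else-branch amounts to over the same three nodes (illustrative only, not code from the patch):

    // Two separate launches, one per norm node, with a no-op view in between.
    // During TG each grid is tiny (e.g. 32 rows for the Q heads and 8 rows for
    // the KV heads of a hypothetical GQA model), so launch overhead dominates.
    ggml_cuda_op_fused_rms_norm(ctx, dst);                 // Q norm
    /* ... the GGML_OP_VIEW node ... */
    ggml_cuda_op_fused_rms_norm(ctx, cgraph->nodes[i+2]);  // K norm

With the fused path those hypothetical 32 + 8 rows become a single 40-block grid. A host-side reference of what the fused kernel computes follows the kernel below.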
@@ -619,3 +619,84 @@ void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx,
     fused_add_add_rms_norm_f32_cuda((const float *)add1->src[0]->data, (const float *)add1->src[1]->data, (const float *)add2->src[1]->data,
             src1_d, (float *)add2->data, dst_d, ne00, nrows, eps, stream);
 }
+
+template <int block_size>
+static __global__ void fused_rms_rms_norm_f32(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps,
+        const char * x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    auto x_row = (const float *)(row < nrows1 ? x1 + row*nb1 : x2 + (row - nrows1)*nb2);
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x_row[col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    auto dst = row < nrows1 ? y1 + row*ncols : y2 + (row - nrows1)*ncols;
+    auto c = row < nrows1 ? c1 : c2;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale * c[col] * x_row[col];
+    }
+}
+
+static void fused_rms_rms_norm_f32_cuda(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps,
+        const char * x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    int nrows = nrows1 + nrows2;
+    if (ncols < 1024) {
+        const dim3 block_dims(256, 1, 1);
+        fused_rms_rms_norm_f32<256><<<nrows, block_dims, 0, stream>>>(ncols, nrows1, nrows2, nb1, nb2, eps, x1, x2, c1, c2, y1, y2);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        fused_rms_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(ncols, nrows1, nrows2, nb1, nb2, eps, x1, x2, c1, c2, y1, y2);
+    }
+}
+
+void ggml_cuda_op_fused_rms_rms_norm([[maybe_unused]] ggml_backend_cuda_context & ctx, [[maybe_unused]] ggml_tensor * rms1, [[maybe_unused]] ggml_tensor * rms2) {
+    GGML_ASSERT(rms1->ne[2] == 1 && rms1->ne[3] == 1);
+    GGML_ASSERT(rms2->ne[2] == 1 && rms2->ne[3] == 1);
+    GGML_ASSERT(rms1->ne[0] == rms2->ne[0]);
+    GGML_ASSERT(rms1->type == GGML_TYPE_F32);
+    GGML_ASSERT(rms2->type == GGML_TYPE_F32);
+    GGML_ASSERT(rms1->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(rms2->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(rms1->src[0]->ne[0] == rms1->src[1]->ne[0]);
+    GGML_ASSERT(rms2->src[0]->ne[0] == rms2->src[1]->ne[0]);
+    GGML_ASSERT(ggml_nrows(rms1->src[1]) == 1);
+    GGML_ASSERT(ggml_nrows(rms2->src[1]) == 1);
+    GGML_ASSERT(rms1->src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(rms2->src[1]->type == GGML_TYPE_F32);
+
+    float eps1, eps2;
+    memcpy(&eps1, rms1->op_params, sizeof(float));
+    memcpy(&eps2, rms2->op_params, sizeof(float));
+    GGML_ASSERT(eps1 == eps2);
+
+    fused_rms_rms_norm_f32_cuda(rms1->ne[0], rms1->ne[1], rms2->ne[1], rms1->nb[1], rms2->nb[1], eps1,
+            (const char *)rms1->src[0]->data, (const char *)rms2->src[0]->data,
+            (const float *)rms1->src[1]->data, (const float *)rms2->src[1]->data,
+            (float *)rms1->data, (float *)rms2->data, ctx.stream());
+
+
+}
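Each block of the kernel normalizes one row: grid rows [0, nrows1) map onto the first tensor (Q) and rows [nrows1, nrows1 + nrows2) onto the second (K); that row split is all the fusion amounts to. A minimal host-side sketch of the semantics, assuming standard RMS-norm (y = c * x / sqrt(mean(x^2) + eps)) and contiguous float rows in place of the byte strides nb1/nb2 that the kernel supports (the function name is illustrative):

    #include <cmath>

    static void fused_rms_rms_norm_ref(int ncols, int nrows1, int nrows2, float eps,
            const float * x1, const float * x2, const float * c1, const float * c2,
            float * y1, float * y2) {
        for (int row = 0; row < nrows1 + nrows2; ++row) {
            const bool first = row < nrows1; // Q rows come first, then K rows
            const float * x = first ? x1 + row*ncols : x2 + (row - nrows1)*ncols;
            const float * c = first ? c1 : c2; // per-tensor norm weights
            float       * y = first ? y1 + row*ncols : y2 + (row - nrows1)*ncols;
            float sum = 0.0f;
            for (int col = 0; col < ncols; ++col) sum += x[col]*x[col];
            const float scale = 1.0f/std::sqrt(sum/ncols + eps); // = rsqrtf(mean + eps)
            for (int col = 0; col < ncols; ++col) y[col] = scale * c[col] * x[col];
        }
    }

The in-kernel reduction relies on ggml's warp_reduce_sum helper; a typical shuffle-based implementation of such a helper looks like this (a sketch, not the helper from this repository):

    // Butterfly reduction: after log2(WARP_SIZE) steps every lane holds the sum.
    static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
    #pragma unroll
        for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE);
        }
        return x;
    }

Note the launcher's block-size choice: rows shorter than 1024 elements get 256-thread blocks, longer rows get 1024 threads. For typical attention head sizes (ncols of 64 to 256) that means 256 threads, i.e. eight warps and the two-level reduction path through s_sum.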
@@ -11,3 +11,5 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
 void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst);

 void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add1, ggml_tensor * add2, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_rms_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * rms1, ggml_tensor * rms2);
@@ -1279,10 +1279,12 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
         if (q_norm) {
             Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Qcur, "Qcur_normed", il);
+            ggml_build_forward_expand(gf, Qcur);
         }
         if (k_norm) {
             Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Kcur, "Kcur_normed", il);
+            ggml_build_forward_expand(gf, Kcur);
         }

         return {Qcur, Kcur, Vcur};
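The two added ggml_build_forward_expand calls are what make the CUDA-side pattern match possible: expanding Qcur and Kcur into the graph right after their norms keeps the two GGML_OP_FUSED_RMS_NORM nodes adjacent in the node order (separated only by a view), so the dispatch loop above finds them at nodes[i] and nodes[i+2]. Without the explicit expansion other nodes could end up interleaved between the two norms and the fused path would never trigger.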
@@ -2451,7 +2451,6 @@ bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias) {
     layer.wk = ml.create_tensor_as_view(ctx_split, layer.wqkv, wk_name.c_str(), { wk->ne[0], wk->ne[1] }, wq->ne[1]*wq->nb[1]);
     layer.wv = ml.create_tensor_as_view(ctx_split, layer.wqkv, wv_name.c_str(), { wv->ne[0], wv->ne[1] }, wq->ne[1]*wq->nb[1] + wk->ne[1]*wk->nb[1]);
     fused_qkv = true;
-    printf("================================== Created merged qkv %s\n", layer.wqkv->name);
     if (bias) {
         auto bq_name = tn(LLM_TENSOR_ATTN_Q, "bias", i);
         auto bk_name = tn(LLM_TENSOR_ATTN_K, "bias", i);