mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-01 09:30:15 +00:00
Fuse the attention gate in Step-3.5-Flash (#1244)
* WIP * This works but is slow * Turn off the up / gate clamps for now * OK we need the clamping * Fuse the clamp (CUDA) * Fuse the clamp (CPU) * WIP * Be able to use merged q, k, v * Be able to use merged up/gate experts * Fuse the clamp (CUDA mmvq) * WIP: graph parallel for Step-3.5 * WIP * This should be it * Cleanup * Fix merge * Not working attempt to extend fused_mul_unary to the Step-3.5 case * It works now, but performance gain is very minor
This commit is contained in:
@@ -72,6 +72,38 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa
|
||||
dst[i] = g * max(-limit, min(limit, y[i]));
|
||||
}
|
||||
|
||||
static __global__ void fused_mul_silu_f32(int ne0, const float * x, const float * y, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
int row = i / ne0;
|
||||
dst[i] = x[row] * y[i] / (1.0f + expf(-x[row]));
|
||||
}
|
||||
|
||||
static __global__ void fused_mul_sigmoid_f32(int ne0, const float * x, const float * y, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
int row = i / ne0;
|
||||
dst[i] = y[i] / (1.0f + expf(-x[row]));
|
||||
}
|
||||
|
||||
static __global__ void fused_mul_silu_f32(int ne0, const float * x, const float * y, float * dst, const int k, float limit) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
int row = i / ne0;
|
||||
float g = x[row] / (1.0f + expf(-x[row]));
|
||||
g = min(g, limit);
|
||||
dst[i] = g * max(-limit, min(limit, y[i]));
|
||||
}
|
||||
|
||||
static __global__ void fused_mul_silu_f32(const float * x, float * dst, const int k, const int ne0) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
@@ -257,6 +289,20 @@ static void fused_mul_silu_f32_cuda(const float * x, const float * y, float * ds
|
||||
}
|
||||
}
|
||||
|
||||
static void fused_mul_silu_f32_cuda(int ne0, const float * x, const float * y, float * dst, const int k, float limit, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
|
||||
if (limit < 1e-6f) {
|
||||
fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(ne0, x, y, dst, k);
|
||||
} else {
|
||||
fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(ne0, x, y, dst, k, limit);
|
||||
}
|
||||
}
|
||||
|
||||
static void fused_mul_sigmoid_f32_cuda(int ne0, const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
|
||||
fused_mul_sigmoid_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(ne0, x, y, dst, k);
|
||||
}
|
||||
|
||||
static void fused_mul_relu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
||||
fused_mul_relu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
|
||||
@@ -410,7 +456,21 @@ void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
if (src1) {
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_are_same_shape(src1, dst));
|
||||
if (!ggml_are_same_shape(src0, src1)) {
|
||||
GGML_ASSERT(src0->ne[0] == 1 && src0->ne[1] == src1->ne[1] && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3]);
|
||||
if (op == GGML_UNARY_OP_SILU) {
|
||||
fused_mul_silu_f32_cuda(src1->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data,
|
||||
ggml_nelements(dst), limit, ctx.stream());
|
||||
}
|
||||
else if (op == GGML_UNARY_OP_SIGMOID) {
|
||||
fused_mul_sigmoid_f32_cuda(src1->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data,
|
||||
ggml_nelements(dst), ctx.stream());
|
||||
} else {
|
||||
GGML_ABORT("Fatal error");
|
||||
}
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), (const float *)src0->data, (const float *)src1->data, (float *)dst->data, limit);
|
||||
} else {
|
||||
|
||||
@@ -6472,9 +6472,20 @@ static struct ggml_tensor * ggml_fused_mul_unary_impl(
|
||||
struct ggml_tensor * b,
|
||||
enum ggml_unary_op op,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(ggml_are_same_shape(b, a));
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
if (!ggml_are_same_shape(b, a)) {
|
||||
GGML_ASSERT(a->ne[0] == 1 && a->ne[1] == b->ne[1] && a->ne[2] == b->ne[2] && a->ne[3] == b->ne[3]);
|
||||
GGML_ASSERT(op == GGML_UNARY_OP_SILU || op == GGML_UNARY_OP_SIGMOID);
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, b) : ggml_dup_tensor(ctx, b);
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) op);
|
||||
result->op = GGML_OP_FUSED_MUL_UNARY;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
return result;
|
||||
}
|
||||
GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
|
||||
//GGML_ASSERT(ggml_are_same_shape(b, a));
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
@@ -15158,17 +15169,11 @@ static void ggml_compute_forward_fused_mul_unary_f32(
|
||||
enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];
|
||||
const float limit = *(const float *)(dst->op_params + 1);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nc = dst->ne[0];
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
const int nr = ggml_nrows(dst);
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
@@ -15177,6 +15182,32 @@ static void ggml_compute_forward_fused_mul_unary_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
if (!ggml_are_same_shape(src0, src1)) {
|
||||
GGML_ASSERT(src0->ne[0] == 1 && ggml_nrows(src0) == nr);
|
||||
GGML_ASSERT(op == GGML_UNARY_OP_SILU || op == GGML_UNARY_OP_SIGMOID);
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
float * z = (float *) ((char *) dst->data + i1*( dst->nb[1]));
|
||||
const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1]));
|
||||
const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1]));
|
||||
float gate = op == GGML_UNARY_OP_SILU ? ggml_silu_f32(x[0]) : 1.0f/(1.0f + expf(-x[0]));
|
||||
if (limit < 1e-6f) {
|
||||
for (int i = 0; i < nc; ++i) z[i] = gate * y[i];
|
||||
} else {
|
||||
gate = MIN(gate, limit);
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
float up = MAX(-limit, MIN(limit, y[i]));
|
||||
z[i] = up * gate;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
float * z = (float *) ((char *) dst->data + i1*( dst->nb[1]));
|
||||
const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1]));
|
||||
|
||||
@@ -9683,13 +9683,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
||||
GGML_ASSERT(wqkv_gate && wqkv_gate->splits[id]);
|
||||
auto gate = llm_build_lora_mm(lctx, ctx0, wqkv_gate->splits[id], input_normed);
|
||||
cb(gate, "attn_gate", il_cb);
|
||||
gate = ggml_sigmoid(ctx0, gate);
|
||||
cb(gate, "attn_gate_sigmoid", il_cb);
|
||||
int nh = split_wo->ne[0]/n_embd_head_v;
|
||||
auto attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, nh, n_tokens);
|
||||
auto gate_3d = ggml_reshape_3d(ctx0, gate, 1, nh, n_tokens);
|
||||
gate_3d = ggml_repeat(ctx0, gate_3d, attn_3d);
|
||||
cur = ggml_mul(ctx0, attn_3d, gate_3d);
|
||||
cur = ggml_fused_mul_unary(ctx0, gate_3d, attn_3d, GGML_UNARY_OP_SIGMOID);
|
||||
cb(attn_3d, "attn_gated_3d", il_cb);
|
||||
}
|
||||
|
||||
@@ -9777,17 +9774,12 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
||||
cb(cur, "wqkv", il);
|
||||
auto gate = llm_build_lora_mm(lctx, ctx0, wqkv_gate, input_normed); // [n_head_l, n_tokens]
|
||||
cb(gate, "attn_gate", il);
|
||||
gate = ggml_sigmoid(ctx0, gate);
|
||||
cb(gate, "attn_gate_sigmoid", il);
|
||||
// reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
|
||||
int n_head_l = hparams.n_head(il);
|
||||
ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, n_head_l, n_tokens);
|
||||
ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
|
||||
gate_3d = ggml_repeat(ctx0, gate_3d, attn_3d);
|
||||
cb(gate_3d, "attn_gate_bcast", il);
|
||||
attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
|
||||
cb(attn_3d, "attn_gated_3d", il);
|
||||
cur = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
|
||||
auto attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, n_head_l, n_tokens);
|
||||
auto gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
|
||||
cur = ggml_fused_mul_unary(ctx0, gate_3d, attn_3d, GGML_UNARY_OP_SIGMOID);
|
||||
cb(cur, "attn_gated_3d", il);
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v * n_head_l, n_tokens);
|
||||
cb(cur, "attn_gated", il);
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
||||
if (model.layers[il].bo) {
|
||||
|
||||
Reference in New Issue
Block a user