mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 02:20:01 +00:00
Fused fused_rms+fused_rms+rope+rope (with -mqkv)
This commit is contained in:
@@ -3249,7 +3249,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
// for (int j = 1; j <= 6; ++j) printf(" %s(%s)\n", ggml_op_name(cgraph->nodes[i+j]->op), cgraph->nodes[i+j]->name);
|
||||
//}
|
||||
// Disabled because currently there is something wrong in the fused kernel implementation
|
||||
if (false && ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
|
||||
if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
|
||||
cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
|
||||
|
||||
@@ -208,74 +208,56 @@ static __global__ void fused_rms_rope_neox_fast(const float * src0_1, const floa
|
||||
int s01_1, int s02_1, int s01_2, int s02_2, int n_dims, float eps) {
|
||||
|
||||
int i0 = 2*threadIdx.y;
|
||||
int i1 = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
int i2 = blockIdx.z*blockDim.z + threadIdx.z;
|
||||
int i2 = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
int i1 = blockIdx.z*blockDim.z + threadIdx.z;
|
||||
|
||||
__shared__ float s_sum[WARP_SIZE];
|
||||
const float * src0, *c;
|
||||
float * dst;
|
||||
int ne1, s01, s02;
|
||||
|
||||
src0_1 += i1*s01_1 + i2*s02_1;
|
||||
src0_2 += i1*s01_2 + i2*s02_2;
|
||||
dst_1 += ne0*(i1 + i2*ne1_1);
|
||||
dst_2 += ne0*(i1 + i2*ne1_2);
|
||||
|
||||
float norm_1 = 1, norm_2 = 1;
|
||||
if (i1 < ne1_1) {
|
||||
float sum = i0 < ne0 ? src0_1[i0]*src0_1[i0] + src0_1[i0+1]*src0_1[i0+1] : 0.0f;
|
||||
sum = warp_reduce_sum(sum);
|
||||
if constexpr (CUDA_ROPE_BLOCK_SIZE > 2*WARP_SIZE) {
|
||||
int warp_id = (i0/2) / WARP_SIZE;
|
||||
int lane_id = (i0/2) % WARP_SIZE;
|
||||
if (lane_id == 0) s_sum[warp_id] = sum;
|
||||
__syncthreads();
|
||||
sum = lane_id < CUDA_ROPE_BLOCK_SIZE / (2*WARP_SIZE) ? s_sum[lane_id] : 0;
|
||||
sum = warp_reduce_sum(sum);
|
||||
}
|
||||
norm_1 = rsqrtf(sum/ne0 + eps);
|
||||
ne1 = ne1_1;
|
||||
s01 = s01_1; s02 = s02_1;
|
||||
src0 = src0_1 + i1*s01 + i2*s02;
|
||||
dst = dst_1 + ne0*(i1 + i2*ne1);
|
||||
c = c_1;
|
||||
} else {
|
||||
i1 -= ne1_1;
|
||||
ne1 = ne1_2;
|
||||
s01 = s01_2; s02 = s02_2;
|
||||
src0 = src0_2 + i1*s01 + i2*s02;
|
||||
dst = dst_2 + ne0*(i1 + i2*ne1);
|
||||
c = c_2;
|
||||
}
|
||||
if (i2 < ne1_2) {
|
||||
float sum = i0 < ne0 ? src0_2[i0]*src0_2[i0] + src0_2[i0+1]*src0_2[i0+1] : 0.0f;
|
||||
|
||||
float sum = i0 < ne0 ? src0[i0]*src0[i0] + src0[i0+1]*src0[i0+1] : 0.0f;
|
||||
sum = warp_reduce_sum(sum);
|
||||
if constexpr (CUDA_ROPE_BLOCK_SIZE > WARP_SIZE) {
|
||||
__shared__ float s_sum[WARP_SIZE];
|
||||
int warp_id = (i0/2) / WARP_SIZE;
|
||||
int lane_id = (i0/2) % WARP_SIZE;
|
||||
if (lane_id == 0) s_sum[warp_id] = sum;
|
||||
__syncthreads();
|
||||
sum = lane_id < CUDA_ROPE_BLOCK_SIZE / WARP_SIZE ? s_sum[lane_id] : 0;
|
||||
sum = warp_reduce_sum(sum);
|
||||
if constexpr (CUDA_ROPE_BLOCK_SIZE > 2*WARP_SIZE) {
|
||||
int warp_id = (i0/2) / WARP_SIZE;
|
||||
int lane_id = (i0/2) % WARP_SIZE;
|
||||
if (lane_id == 0) s_sum[warp_id] = sum;
|
||||
__syncthreads();
|
||||
sum = lane_id < CUDA_ROPE_BLOCK_SIZE / (2*WARP_SIZE) ? s_sum[lane_id] : 0;
|
||||
sum = warp_reduce_sum(sum);
|
||||
}
|
||||
norm_2 = rsqrtf(sum/ne0 + eps);
|
||||
}
|
||||
float norm = rsqrtf(sum/ne0 + eps);
|
||||
|
||||
if (i0 >= ne0) return;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
if (i1 < ne1_1) {
|
||||
dst_1[i0 + 0] = norm_1*c_1[i0 + 0]*src0_1[i0 + 0];
|
||||
dst_1[i0 + 1] = norm_1*c_1[i0 + 1]*src0_1[i0 + 1];
|
||||
}
|
||||
if (i1 < ne1_2) {
|
||||
dst_2[i0 + 0] = norm_2*c_2[i0 + 0]*src0_2[i0 + 0];
|
||||
dst_2[i0 + 1] = norm_2*c_2[i0 + 1]*src0_2[i0 + 1];
|
||||
}
|
||||
dst[i0 + 0] = norm*c[i0 + 0]*src0[i0 + 0];
|
||||
dst[i0 + 1] = norm*c[i0 + 1]*src0[i0 + 1];
|
||||
return;
|
||||
}
|
||||
|
||||
const float cos_theta = src1[i2*ne0 + i0 + 0];
|
||||
const float sin_theta = src1[i2*ne0 + i0 + 1];
|
||||
|
||||
if (i1 < ne1_1) {
|
||||
const float x0 = norm_1*c_1[i0/2 + 0]*src0_1[i0/2 + 0];
|
||||
const float x1 = norm_1*c_1[i0/2 + n_dims/2]*src0_1[i0/2 + n_dims/2];
|
||||
dst_1[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_1[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
if (i1 < ne1_2) {
|
||||
const float x0 = norm_2*c_2[i0/2 + 0]*src0_2[i0/2 + 0];
|
||||
const float x1 = norm_2*c_2[i0/2 + n_dims/2]*src0_2[i0/2 + n_dims/2];
|
||||
dst_2[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_2[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
const float x0 = norm*c[i0/2 + 0]*src0[i0/2 + 0];
|
||||
const float x1 = norm*c[i0/2 + n_dims/2]*src0[i0/2 + n_dims/2];
|
||||
dst[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
}
|
||||
|
||||
@@ -540,8 +522,7 @@ static void fused_rms_rope_neox_fast_cuda(const float * src0_1, const float * sr
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
GGML_ASSERT(ne0 <= 2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
int ne1 = std::max(ne1_1, ne1_2);
|
||||
const dim3 block_nums(ne1, 1, ne2);
|
||||
const dim3 block_nums(ne2, 1, ne1_1 + ne1_2);
|
||||
fused_rms_rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0_1, src0_2, src1, c_1, c_2, dst_1, dst_2, ne0, ne1_1, ne1_2,
|
||||
s01_1, s02_1, s01_2, s02_2, n_dims, eps);
|
||||
}
|
||||
@@ -974,22 +955,26 @@ bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
|
||||
const auto src0_1 = rms_1->src[0];
|
||||
const auto src0_2 = rms_2->src[0];
|
||||
const auto c_1 = rms_1->src[1];
|
||||
const auto c_2 = rms_2->src[1];
|
||||
|
||||
if (src0_1->type != GGML_TYPE_F32) return false;
|
||||
if (src0_2->type != GGML_TYPE_F32) return false;
|
||||
if (dst1->type != GGML_TYPE_F32) return false;
|
||||
if (dst2->type != GGML_TYPE_F32) return false;
|
||||
if (src1->type != dst1->type) return false;
|
||||
if (c_1->type != GGML_TYPE_F32) return false;
|
||||
if (c_2->type != GGML_TYPE_F32) return false;
|
||||
|
||||
if (src0_1->ne[0] != src0_2->ne[0]) return false;
|
||||
if (src0_1->ne[2] != src0_2->ne[2]) return false;
|
||||
if (src0_1->ne[3] != src0_2->ne[3]) return false;
|
||||
if (src0_1->ne[0] > 2*CUDA_ROPE_BLOCK_SIZE) return false;
|
||||
|
||||
GGML_ASSERT(ggml_nrows(rms_1->src[1]) == 1);
|
||||
GGML_ASSERT(ggml_nrows(rms_2->src[1]) == 1);
|
||||
GGML_ASSERT(rms_1->src[1]->ne[0] == src0_1->ne[0]);
|
||||
GGML_ASSERT(rms_2->src[1]->ne[0] == src0_2->ne[0]);
|
||||
GGML_ASSERT(ggml_nrows(c_1) == 1);
|
||||
GGML_ASSERT(ggml_nrows(c_2) == 1);
|
||||
GGML_ASSERT(c_1->ne[0] == src0_1->ne[0]);
|
||||
GGML_ASSERT(c_2->ne[0] == src0_2->ne[0]);
|
||||
|
||||
const int n_dims = ((const int32_t *) src1->op_params)[1];
|
||||
const int mode = ((const int32_t *) src1->op_params)[2];
|
||||
@@ -1023,7 +1008,7 @@ bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
// compute
|
||||
fused_rms_rope_neox_fast_cuda(
|
||||
(const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
|
||||
(const float *)rms_1->src[1]->data, (const float *)rms_2->src[1]->data,
|
||||
(const float *)c_1->data, (const float *)c_2->data,
|
||||
(float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, eps1,
|
||||
ctx.stream());
|
||||
return true;
|
||||
|
||||
Reference in New Issue
Block a user