Fused fused_rms+fused_rms+rope+rope (with -mqkv)

This commit is contained in:
Iwan Kawrakow
2025-11-03 18:21:42 +02:00
parent 0dc705587a
commit 9f9866b710
2 changed files with 45 additions and 60 deletions

View File

@@ -3249,7 +3249,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
// for (int j = 1; j <= 6; ++j) printf(" %s(%s)\n", ggml_op_name(cgraph->nodes[i+j]->op), cgraph->nodes[i+j]->name);
//}
// Disabled because currently there is something wrong in the fused kernel implementation
if (false && ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&

View File

@@ -208,74 +208,56 @@ static __global__ void fused_rms_rope_neox_fast(const float * src0_1, const floa
int s01_1, int s02_1, int s01_2, int s02_2, int n_dims, float eps) {
int i0 = 2*threadIdx.y;
int i1 = blockIdx.x*blockDim.x + threadIdx.x;
int i2 = blockIdx.z*blockDim.z + threadIdx.z;
int i2 = blockIdx.x*blockDim.x + threadIdx.x;
int i1 = blockIdx.z*blockDim.z + threadIdx.z;
__shared__ float s_sum[WARP_SIZE];
const float * src0, *c;
float * dst;
int ne1, s01, s02;
src0_1 += i1*s01_1 + i2*s02_1;
src0_2 += i1*s01_2 + i2*s02_2;
dst_1 += ne0*(i1 + i2*ne1_1);
dst_2 += ne0*(i1 + i2*ne1_2);
float norm_1 = 1, norm_2 = 1;
if (i1 < ne1_1) {
float sum = i0 < ne0 ? src0_1[i0]*src0_1[i0] + src0_1[i0+1]*src0_1[i0+1] : 0.0f;
sum = warp_reduce_sum(sum);
if constexpr (CUDA_ROPE_BLOCK_SIZE > 2*WARP_SIZE) {
int warp_id = (i0/2) / WARP_SIZE;
int lane_id = (i0/2) % WARP_SIZE;
if (lane_id == 0) s_sum[warp_id] = sum;
__syncthreads();
sum = lane_id < CUDA_ROPE_BLOCK_SIZE / (2*WARP_SIZE) ? s_sum[lane_id] : 0;
sum = warp_reduce_sum(sum);
}
norm_1 = rsqrtf(sum/ne0 + eps);
ne1 = ne1_1;
s01 = s01_1; s02 = s02_1;
src0 = src0_1 + i1*s01 + i2*s02;
dst = dst_1 + ne0*(i1 + i2*ne1);
c = c_1;
} else {
i1 -= ne1_1;
ne1 = ne1_2;
s01 = s01_2; s02 = s02_2;
src0 = src0_2 + i1*s01 + i2*s02;
dst = dst_2 + ne0*(i1 + i2*ne1);
c = c_2;
}
if (i2 < ne1_2) {
float sum = i0 < ne0 ? src0_2[i0]*src0_2[i0] + src0_2[i0+1]*src0_2[i0+1] : 0.0f;
float sum = i0 < ne0 ? src0[i0]*src0[i0] + src0[i0+1]*src0[i0+1] : 0.0f;
sum = warp_reduce_sum(sum);
if constexpr (CUDA_ROPE_BLOCK_SIZE > WARP_SIZE) {
__shared__ float s_sum[WARP_SIZE];
int warp_id = (i0/2) / WARP_SIZE;
int lane_id = (i0/2) % WARP_SIZE;
if (lane_id == 0) s_sum[warp_id] = sum;
__syncthreads();
sum = lane_id < CUDA_ROPE_BLOCK_SIZE / WARP_SIZE ? s_sum[lane_id] : 0;
sum = warp_reduce_sum(sum);
if constexpr (CUDA_ROPE_BLOCK_SIZE > 2*WARP_SIZE) {
int warp_id = (i0/2) / WARP_SIZE;
int lane_id = (i0/2) % WARP_SIZE;
if (lane_id == 0) s_sum[warp_id] = sum;
__syncthreads();
sum = lane_id < CUDA_ROPE_BLOCK_SIZE / (2*WARP_SIZE) ? s_sum[lane_id] : 0;
sum = warp_reduce_sum(sum);
}
norm_2 = rsqrtf(sum/ne0 + eps);
}
float norm = rsqrtf(sum/ne0 + eps);
if (i0 >= ne0) return;
if (i0 >= n_dims) {
if (i1 < ne1_1) {
dst_1[i0 + 0] = norm_1*c_1[i0 + 0]*src0_1[i0 + 0];
dst_1[i0 + 1] = norm_1*c_1[i0 + 1]*src0_1[i0 + 1];
}
if (i1 < ne1_2) {
dst_2[i0 + 0] = norm_2*c_2[i0 + 0]*src0_2[i0 + 0];
dst_2[i0 + 1] = norm_2*c_2[i0 + 1]*src0_2[i0 + 1];
}
dst[i0 + 0] = norm*c[i0 + 0]*src0[i0 + 0];
dst[i0 + 1] = norm*c[i0 + 1]*src0[i0 + 1];
return;
}
const float cos_theta = src1[i2*ne0 + i0 + 0];
const float sin_theta = src1[i2*ne0 + i0 + 1];
if (i1 < ne1_1) {
const float x0 = norm_1*c_1[i0/2 + 0]*src0_1[i0/2 + 0];
const float x1 = norm_1*c_1[i0/2 + n_dims/2]*src0_1[i0/2 + n_dims/2];
dst_1[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
dst_1[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
if (i1 < ne1_2) {
const float x0 = norm_2*c_2[i0/2 + 0]*src0_2[i0/2 + 0];
const float x1 = norm_2*c_2[i0/2 + n_dims/2]*src0_2[i0/2 + n_dims/2];
dst_2[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
dst_2[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
const float x0 = norm*c[i0/2 + 0]*src0[i0/2 + 0];
const float x1 = norm*c[i0/2 + n_dims/2]*src0[i0/2 + n_dims/2];
dst[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
dst[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
@@ -540,8 +522,7 @@ static void fused_rms_rope_neox_fast_cuda(const float * src0_1, const float * sr
GGML_ASSERT(ne0 % 2 == 0);
GGML_ASSERT(ne0 <= 2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
int ne1 = std::max(ne1_1, ne1_2);
const dim3 block_nums(ne1, 1, ne2);
const dim3 block_nums(ne2, 1, ne1_1 + ne1_2);
fused_rms_rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0_1, src0_2, src1, c_1, c_2, dst_1, dst_2, ne0, ne1_1, ne1_2,
s01_1, s02_1, s01_2, s02_2, n_dims, eps);
}
@@ -974,22 +955,26 @@ bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tens
const auto src0_1 = rms_1->src[0];
const auto src0_2 = rms_2->src[0];
const auto c_1 = rms_1->src[1];
const auto c_2 = rms_2->src[1];
if (src0_1->type != GGML_TYPE_F32) return false;
if (src0_2->type != GGML_TYPE_F32) return false;
if (dst1->type != GGML_TYPE_F32) return false;
if (dst2->type != GGML_TYPE_F32) return false;
if (src1->type != dst1->type) return false;
if (c_1->type != GGML_TYPE_F32) return false;
if (c_2->type != GGML_TYPE_F32) return false;
if (src0_1->ne[0] != src0_2->ne[0]) return false;
if (src0_1->ne[2] != src0_2->ne[2]) return false;
if (src0_1->ne[3] != src0_2->ne[3]) return false;
if (src0_1->ne[0] > 2*CUDA_ROPE_BLOCK_SIZE) return false;
GGML_ASSERT(ggml_nrows(rms_1->src[1]) == 1);
GGML_ASSERT(ggml_nrows(rms_2->src[1]) == 1);
GGML_ASSERT(rms_1->src[1]->ne[0] == src0_1->ne[0]);
GGML_ASSERT(rms_2->src[1]->ne[0] == src0_2->ne[0]);
GGML_ASSERT(ggml_nrows(c_1) == 1);
GGML_ASSERT(ggml_nrows(c_2) == 1);
GGML_ASSERT(c_1->ne[0] == src0_1->ne[0]);
GGML_ASSERT(c_2->ne[0] == src0_2->ne[0]);
const int n_dims = ((const int32_t *) src1->op_params)[1];
const int mode = ((const int32_t *) src1->op_params)[2];
@@ -1023,7 +1008,7 @@ bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tens
// compute
fused_rms_rope_neox_fast_cuda(
(const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
(const float *)rms_1->src[1]->data, (const float *)rms_2->src[1]->data,
(const float *)c_1->data, (const float *)c_2->data,
(float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, eps1,
ctx.stream());
return true;