Fused rms+rms+rope+rope (neox) - not working

This commit is contained in:
Iwan Kawrakow
2025-11-02 09:48:25 +02:00
parent 623d775929
commit 332c4d6680
3 changed files with 174 additions and 1 deletions

View File

@@ -3245,7 +3245,20 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_rms_norm(ctx, dst);
break;
case GGML_OP_FUSED_RMS_NORM:
if (i + 2 < cgraph->n_nodes &&
//if (i + 6 < cgraph->n_nodes) {
// printf("=== Fused rms_norm(%s)\n", dst->name);
// for (int j = 1; j <= 6; ++j) printf(" %s(%s)\n", ggml_op_name(cgraph->nodes[i+j]->op), cgraph->nodes[i+j]->name);
//}
if (false && ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
cgraph->nodes[i+4]->op == GGML_OP_ROPE_FAST &&
ggml_cuda_op_fused_rms_rope_fast(ctx, cgraph->nodes[i+3], cgraph->nodes[i+4])) {
//printf("Fused rms+rms+rope+rope\n");
i += 4;
}
else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
dst->ne[2] == 1 && cgraph->nodes[i+2]->ne[2] == 1) {

View File

@@ -202,6 +202,83 @@ static __global__ void fused_rope_neox_fast(const float * src0_1, const float *
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
static __global__ void fused_rms_rope_neox_fast(const float * src0_1, const float * src0_2, const float * src1,
const float * c_1, const float * c_2,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2,
int s01_1, int s02_1, int s01_2, int s02_2, int n_dims, float eps) {
int i0 = 2*threadIdx.y;
int i1 = blockIdx.x*blockDim.x + threadIdx.x;
int i2 = blockIdx.z*blockDim.z + threadIdx.z;
__shared__ float s_sum[WARP_SIZE];
src0_1 += i1*s01_1 + i2*s02_1;
src0_2 += i1*s01_2 + i2*s02_2;
dst_1 += ne0*(i1 + i2*ne1_1);
dst_2 += ne0*(i1 + i2*ne1_2);
float norm_1 = 1, norm_2 = 1;
if (i1 < ne1_1) {
float sum = i0 < ne0 ? src0_1[i0]*src0_1[i0] + src0_1[i0+1]*src0_1[i0+1] : 0.0f;
sum = warp_reduce_sum(sum);
if constexpr (CUDA_ROPE_BLOCK_SIZE > 2*WARP_SIZE) {
int warp_id = (i0/2) / WARP_SIZE;
int lane_id = (i0/2) % WARP_SIZE;
if (lane_id == 0) s_sum[warp_id] = sum;
__syncthreads();
sum = lane_id < CUDA_ROPE_BLOCK_SIZE / (2*WARP_SIZE) ? s_sum[lane_id] : 0;
sum = warp_reduce_sum(sum);
}
norm_1 = rsqrtf(sum/ne0 + eps);
}
if (i2 < ne1_2) {
float sum = i0 < ne0 ? src0_2[i0]*src0_2[i0] + src0_2[i0+1]*src0_2[i0+1] : 0.0f;
sum = warp_reduce_sum(sum);
if constexpr (CUDA_ROPE_BLOCK_SIZE > 2*WARP_SIZE) {
int warp_id = (i0/2) / WARP_SIZE;
int lane_id = (i0/2) % WARP_SIZE;
if (lane_id == 0) s_sum[warp_id] = sum;
__syncthreads();
sum = lane_id < CUDA_ROPE_BLOCK_SIZE / (2*WARP_SIZE) ? s_sum[lane_id] : 0;
sum = warp_reduce_sum(sum);
}
norm_2 = rsqrtf(sum/ne0 + eps);
}
if (i0 >= ne0) return;
if (i0 >= n_dims) {
if (i1 < ne1_1) {
dst_1[i0 + 0] = norm_1*c_1[i0 + 0]*src0_1[i0 + 0];
dst_1[i0 + 1] = norm_1*c_1[i0 + 1]*src0_1[i0 + 1];
}
if (i1 < ne1_2) {
dst_2[i0 + 0] = norm_2*c_2[i0 + 0]*src0_2[i0 + 0];
dst_2[i0 + 1] = norm_2*c_2[i0 + 1]*src0_2[i0 + 1];
}
return;
}
const float cos_theta = src1[i2*ne0 + i0 + 0];
const float sin_theta = src1[i2*ne0 + i0 + 1];
if (i1 < ne1_1) {
const float x0 = norm_1*c_1[i0/2 + 0]*src0_1[i0/2 + 0];
const float x1 = norm_1*c_1[i0/2 + n_dims/2]*src0_1[i0/2 + n_dims/2];
dst_1[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
dst_1[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
if (i1 < ne1_2) {
const float x0 = norm_2*c_2[i0/2 + 0]*src0_2[i0/2 + 0];
const float x1 = norm_2*c_2[i0/2 + n_dims/2]*src0_2[i0/2 + n_dims/2];
dst_2[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
dst_2[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
}
static __global__ void rope_norm_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int nelem,
int s01, int s02, int n_dims) {
int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
@@ -456,6 +533,19 @@ static void fused_rope_neox_fast_cuda(const float * src0_1, const float * src0_2
s01_1, s02_1, s01_2, s02_2, n_dims);
}
static void fused_rms_rope_neox_fast_cuda(const float * src0_1, const float * src0_2, const float * src1,
const float * c_1, const float * c_2,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int ne2, int s01_1, int s02_1, int s01_2, int s02_2,
int n_dims, float eps, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
GGML_ASSERT(ne0 <= 2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
int ne1 = std::max(ne1_1, ne1_2);
const dim3 block_nums(ne1, 1, ne2);
fused_rms_rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0_1, src0_2, src1, c_1, c_2, dst_1, dst_2, ne0, ne1_1, ne1_2,
s01_1, s02_1, s01_2, s02_2, n_dims, eps);
}
static void fused_rope_norm_fast_cuda(const float * src0_1, const float * src0_2, const float * src1,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int ne2, int s01_1, int s02_1, int s01_2, int s02_2,
int n_dims, cudaStream_t stream) {
@@ -870,3 +960,71 @@ bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor *
}
return true;
}
bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) {
if (dst1->src[1] != dst2->src[1]) return false;
const auto rms_1 = dst1->src[0];
const auto rms_2 = dst2->src[0];
const auto src1 = dst1->src[1];
if (rms_1->op != GGML_OP_FUSED_RMS_NORM) return false;
if (rms_2->op != GGML_OP_FUSED_RMS_NORM) return false;
const auto src0_1 = rms_1->src[0];
const auto src0_2 = rms_2->src[0];
if (src0_1->type != GGML_TYPE_F32) return false;
if (src0_2->type != GGML_TYPE_F32) return false;
if (dst1->type != GGML_TYPE_F32) return false;
if (dst2->type != GGML_TYPE_F32) return false;
if (src1->type != dst1->type) return false;
if (src0_1->ne[0] != src0_2->ne[0]) return false;
if (src0_1->ne[2] != src0_2->ne[2]) return false;
if (src0_1->ne[3] != src0_2->ne[3]) return false;
if (src0_1->ne[0] > 2*CUDA_ROPE_BLOCK_SIZE) return false;
GGML_ASSERT(ggml_nrows(rms_1->src[1]) == 1);
GGML_ASSERT(ggml_nrows(rms_2->src[1]) == 1);
GGML_ASSERT(rms_1->src[1]->ne[0] == src0_1->ne[0]);
GGML_ASSERT(rms_2->src[1]->ne[0] == src0_2->ne[0]);
const int n_dims = ((const int32_t *) src1->op_params)[1];
const int mode = ((const int32_t *) src1->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_vision || is_mrope) return false; // not implemented
if (!is_neox) return false; // TODO
float eps1, eps2;
memcpy(&eps1, rms_1->op_params, sizeof(float));
memcpy(&eps2, rms_2->op_params, sizeof(float));
if (eps1 != eps2) return false;
const int64_t ne00 = src0_1->ne[0]; // head dims
const int64_t ne02 = src0_1->ne[2]; // num tokens
const int64_t ne01_1 = src0_1->ne[1]; // num heads
const int64_t ne01_2 = src0_2->ne[1]; // num heads
const size_t s01_1 = src0_1->nb[1] / ggml_type_size(src0_1->type);
const size_t s02_1 = src0_1->nb[2] / ggml_type_size(src0_1->type);
const size_t s01_2 = src0_2->nb[1] / ggml_type_size(src0_2->type);
const size_t s02_2 = src0_2->nb[2] / ggml_type_size(src0_2->type);
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
}
// compute
fused_rms_rope_neox_fast_cuda(
(const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
(const float *)rms_1->src[1]->data, (const float *)rms_2->src[1]->data,
(float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, eps1,
ctx.stream());
return true;
}

View File

@@ -11,3 +11,5 @@ void ggml_cuda_op_rope_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);
bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);