Fused rms_norm WIP

This commit is contained in:
Iwan Kawrakow
2024-09-07 21:38:49 +03:00
parent 4d5c76b977
commit 889cda0bba
4 changed files with 134 additions and 0 deletions

View File

@@ -2248,6 +2248,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_RMS_NORM:
ggml_cuda_op_rms_norm(ctx, dst);
break;
case GGML_OP_FUSED_RMS_NORM:
ggml_cuda_op_fused_rms_norm(ctx, dst);
break;
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -2871,6 +2874,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
//case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_SCALE:
case GGML_OP_SOFTCAP:
case GGML_OP_SQR:

View File

@@ -131,6 +131,77 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
}
template <int block_size>
static __global__ void fused_rms_norm_f32(const float * x, const float * y, const float * z, float * dst, const int ncols,
const int64_t ne0[4], const int64_t ne1[4], const int64_t ne2[4], const size_t nb1[4], const size_t nb2[4], const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
int64_t i03 = row/ne0[3];
int64_t i02 = (row - i03*ne0[3])/ne0[2];
int64_t i01 = (row - i03*ne0[3] - i02*ne0[2])/ne0[1];
if (y && z) {
int64_t i13 = i03 % ne1[3];
int64_t i12 = i02 % ne1[2];
int64_t i11 = i01 % ne1[1];
int64_t i23 = i03 % ne2[3];
int64_t i22 = i02 % ne2[2];
int64_t i21 = i01 % ne2[1];
const float * yr = (const float *)((const char *)x + i13*nb1[3] + i12*nb1[2] + i11*nb1[11]);
const float * zr = (const float *)((const char *)z + i23*nb2[3] + i22*nb2[2] + i21*nb1[11]);
for (int col = tid; col < ncols; col += block_size) {
int64_t i01 = col % ne1[0];
int64_t i02 = col % ne2[0];
dst[row*ncols + col] = scale * yr[i01] * x[row*ncols + col] + zr[i02];
}
}
else if (y) {
int64_t i13 = i03 % ne1[3];
int64_t i12 = i02 % ne1[2];
int64_t i11 = i01 % ne1[1];
const float * yr = (const float *)((const char *)x + i13*nb1[3] + i12*nb1[2] + i11*nb1[11]);
for (int col = tid; col < ncols; col += block_size) {
int64_t i01 = col % ne1[0];
dst[row*ncols + col] = scale * yr[i01] * x[row*ncols + col];
}
}
else {
int64_t i23 = i03 % ne2[3];
int64_t i22 = i02 % ne2[2];
int64_t i21 = i01 % ne2[1];
const float * zr = (const float *)((const char *)z + i23*nb2[3] + i22*nb2[2] + i21*nb1[11]);
for (int col = tid; col < ncols; col += block_size) {
int64_t i02 = col % ne2[0];
dst[row*ncols + col] = scale * x[row*ncols + col] + zr[i02];
}
}
}
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
@@ -163,6 +234,20 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
}
}
//fused_rms_norm_f32_cuda(src0_d, src1_d, src2_d, dst_d, ne00, nrows, eps, ne0, ne1, ne2, nb1, nb2, stream);
static void fused_rms_norm_f32_cuda(const float * x, const float * y, const float * z, float * dst,
const int ncols, const int nrows, const float eps, const int64_t ne0[4], const int64_t ne1[4], const int64_t ne2[4],
const size_t nb1[4], const size_t nb2[4], cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, z, dst, ncols, ne0, ne1, ne2, nb1, nb2, eps);
} else {
const dim3 block_dims(1024, 1, 1);
fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, z, dst, ncols, ne0, ne1, ne2, nb1, nb2, eps);
}
}
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -222,3 +307,41 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (!dst->src[1] && !dst->src[2]) {
ggml_cuda_op_rms_norm(ctx, dst);
return;
}
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(!dst->src[2] || dst->src[2]->type == GGML_TYPE_F32);
if (dst->src[1] && dst->src[2]) {
GGML_ASSERT(dst->src[1]->ne[0] == dst->src[2]->ne[0]);
}
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
const float * src1_d = dst->src[1] ? (const float *)dst->src[1]->data : nullptr;
const float * src2_d = dst->src[2] ? (const float *)dst->src[2]->data : nullptr;
auto ne0 = src0->ne;
auto ne1 = dst->src[1] ? dst->src[1]->ne : nullptr;
auto ne2 = dst->src[2] ? dst->src[2]->ne : nullptr;
auto nb1 = dst->src[1] ? dst->src[1]->nb : nullptr;
auto nb2 = dst->src[2] ? dst->src[2]->nb : nullptr;
fused_rms_norm_f32_cuda(src0_d, src1_d, src2_d, dst_d, ne00, nrows, eps, ne0, ne1, ne2, nb1, nb2, stream);
}

View File

@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -5751,6 +5751,11 @@ static struct ggml_tensor * ggml_fused_rms_norm_impl(
return ggml_rms_norm_impl(ctx, a, eps, inplace);
}
//printf("%s: %zd x %zd x %zd %zd", __func__, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
//if (b) printf(", b = %zd x %zd x %zd %zd, ", b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
//if (c) printf(", c = %zd x %zd x %zd %zd, ", c->ne[0], c->ne[1], c->ne[2], c->ne[3]);
//printf("\n");
bool is_node = false;
if (!inplace && (a->grad)) {