Fused rms_norm WIP

This commit is contained in:
Iwan Kawrakow
2024-09-08 06:38:58 +03:00
parent 5bbbfc62da
commit 03fa830c5f
4 changed files with 21 additions and 51 deletions

View File

@@ -1164,14 +1164,12 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
float eps);
GGML_API struct ggml_tensor * ggml_fused_rms_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
float eps);
// group normalize along ne0*ne1*n_groups

View File

@@ -197,7 +197,6 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
}
}
//fused_rms_norm_f32_cuda(src0_d, src1_d, src2_d, dst_d, ne00, nrows, eps, ne0, ne1, ne2, nb1, nb2, stream);
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
@@ -271,20 +270,24 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
}
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (!dst->src[1] && !dst->src[2]) {
if (!dst->src[1]) {
ggml_cuda_op_rms_norm(ctx, dst);
return;
}
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
GGML_ASSERT(ggml_nrows(src1) == 1);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
@@ -292,7 +295,5 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
const float * src1_d = (const float *)dst->src[1]->data;
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
}

View File

@@ -5743,18 +5743,18 @@ static struct ggml_tensor * ggml_fused_rms_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
float eps,
bool inplace) {
if (!b && !c) {
if (!b) {
return ggml_rms_norm_impl(ctx, a, eps, inplace);
}
//printf("%s: %zd x %zd x %zd %zd", __func__, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
//if (b) printf(", b = %zd x %zd x %zd %zd, ", b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
//if (c) printf(", c = %zd x %zd x %zd %zd, ", c->ne[0], c->ne[1], c->ne[2], c->ne[3]);
//printf("\n");
if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) {
struct ggml_tensor * result = ggml_rms_norm_impl(ctx, a, eps, inplace);
result = ggml_mul_impl(ctx, result, b, inplace);
return result;
}
bool is_node = false;
@@ -5770,7 +5770,6 @@ static struct ggml_tensor * ggml_fused_rms_norm_impl(
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
result->src[2] = c;
return result;
}
@@ -5779,18 +5778,16 @@ struct ggml_tensor * ggml_fused_rms_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
float eps) {
return ggml_fused_rms_norm_impl(ctx, a, b, c, eps, false);
return ggml_fused_rms_norm_impl(ctx, a, b, eps, false);
}
struct ggml_tensor * ggml_fused_rms_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
float eps) {
return ggml_fused_rms_norm_impl(ctx, a, b, c, eps, true);
return ggml_fused_rms_norm_impl(ctx, a, b, eps, true);
}
// ggml_rms_norm_back
@@ -12517,9 +12514,8 @@ static void ggml_compute_forward_fused_rms_norm_f32(
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src2 = dst->src[2];
if (!src1 && !src2) {
if (!src1) {
ggml_compute_forward_rms_norm_f32(params, dst);
return;
}
@@ -12527,8 +12523,9 @@ static void ggml_compute_forward_fused_rms_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(!src1 || src1->nb[0] == sizeof(float));
GGML_ASSERT(!src2 || src2->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src1->ne[0] == src0->ne[0]);
GGML_ASSERT(ggml_nrows(src1) == 1);
const int ith = params->ith;
const int nth = params->nth;
@@ -12555,37 +12552,11 @@ static void ggml_compute_forward_fused_rms_norm_f32(
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
const float scale = 1.0f/sqrtf(mean + eps);
ggml_vec_mul_f32(ne00, y, x, (const float *)src1->data);
ggml_vec_scale_f32(ne00, y, scale);
if (src1) {
const int64_t i13 = i03 % src1->ne[3];
const int64_t i12 = i02 % src1->ne[2];
const int64_t i11 = i01 % src1->ne[1];
const int64_t nr0 = ne00 / src1->ne[0];
float * src1_ptr = (float *) ((char *) src1->data + i13*src1->nb[3] + i12*src1->nb[2] + i11*src1->nb[1]);
for (int64_t r = 0 ; r < nr0; ++r) {
ggml_vec_mul_f32(src1->ne[0], y + r*src1->ne[0], y + r*src1->ne[0], src1_ptr);
}
}
if (src2) {
const int64_t i23 = i03 % src2->ne[3];
const int64_t i22 = i02 % src2->ne[2];
const int64_t i21 = i01 % src2->ne[1];
const int64_t nr0 = ne00 / src2->ne[0];
float * src2_ptr = (float *) ((char *) src2->data + i23*src2->nb[3] + i22*src2->nb[2] + i21*src2->nb[1]);
for (int64_t r = 0 ; r < nr0; ++r) {
ggml_vec_mul_f32(src1->ne[0], y + r*src1->ne[0], y + r*src1->ne[0], src2_ptr);
}
}
}
}
}

View File

@@ -7988,8 +7988,8 @@ static struct ggml_tensor * llm_build_norm(
const llm_build_cb & cb,
int il, float scale_eps = 1) {
if (type == LLM_NORM_RMS) {
cur = ggml_fused_rms_norm(ctx, cur, mw, mb, scale_eps * hparams.f_norm_rms_eps);
if (type == LLM_NORM_RMS && !mb) {
cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps);
cb(cur, "fused_norm", il);
return cur;
}