Fused rms_norm WIP

This commit is contained in:
Iwan Kawrakow
2024-09-07 23:04:31 +03:00
parent 889cda0bba
commit 5bbbfc62da
2 changed files with 11 additions and 60 deletions

View File

@@ -2874,7 +2874,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
//case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_SCALE:
case GGML_OP_SOFTCAP:
case GGML_OP_SQR:

View File

@@ -132,8 +132,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
template <int block_size>
static __global__ void fused_rms_norm_f32(const float * x, const float * y, const float * z, float * dst, const int ncols,
const int64_t ne0[4], const int64_t ne1[4], const int64_t ne2[4], const size_t nb1[4], const size_t nb2[4], const float eps) {
static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
@@ -161,44 +160,8 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, cons
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
int64_t i03 = row/ne0[3];
int64_t i02 = (row - i03*ne0[3])/ne0[2];
int64_t i01 = (row - i03*ne0[3] - i02*ne0[2])/ne0[1];
if (y && z) {
int64_t i13 = i03 % ne1[3];
int64_t i12 = i02 % ne1[2];
int64_t i11 = i01 % ne1[1];
int64_t i23 = i03 % ne2[3];
int64_t i22 = i02 % ne2[2];
int64_t i21 = i01 % ne2[1];
const float * yr = (const float *)((const char *)x + i13*nb1[3] + i12*nb1[2] + i11*nb1[11]);
const float * zr = (const float *)((const char *)z + i23*nb2[3] + i22*nb2[2] + i21*nb1[11]);
for (int col = tid; col < ncols; col += block_size) {
int64_t i01 = col % ne1[0];
int64_t i02 = col % ne2[0];
dst[row*ncols + col] = scale * yr[i01] * x[row*ncols + col] + zr[i02];
}
}
else if (y) {
int64_t i13 = i03 % ne1[3];
int64_t i12 = i02 % ne1[2];
int64_t i11 = i01 % ne1[1];
const float * yr = (const float *)((const char *)x + i13*nb1[3] + i12*nb1[2] + i11*nb1[11]);
for (int col = tid; col < ncols; col += block_size) {
int64_t i01 = col % ne1[0];
dst[row*ncols + col] = scale * yr[i01] * x[row*ncols + col];
}
}
else {
int64_t i23 = i03 % ne2[3];
int64_t i22 = i02 % ne2[2];
int64_t i21 = i01 % ne2[1];
const float * zr = (const float *)((const char *)z + i23*nb2[3] + i22*nb2[2] + i21*nb1[11]);
for (int col = tid; col < ncols; col += block_size) {
int64_t i02 = col % ne2[0];
dst[row*ncols + col] = scale * x[row*ncols + col] + zr[i02];
}
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
}
}
@@ -235,16 +198,15 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
}
//fused_rms_norm_f32_cuda(src0_d, src1_d, src2_d, dst_d, ne00, nrows, eps, ne0, ne1, ne2, nb1, nb2, stream);
static void fused_rms_norm_f32_cuda(const float * x, const float * y, const float * z, float * dst,
const int ncols, const int nrows, const float eps, const int64_t ne0[4], const int64_t ne1[4], const int64_t ne2[4],
const size_t nb1[4], const size_t nb2[4], cudaStream_t stream) {
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, z, dst, ncols, ne0, ne1, ne2, nb1, nb2, eps);
fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, z, dst, ncols, ne0, ne1, ne2, nb1, nb2, eps);
fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
}
}
@@ -322,11 +284,7 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(!dst->src[2] || dst->src[2]->type == GGML_TYPE_F32);
if (dst->src[1] && dst->src[2]) {
GGML_ASSERT(dst->src[1]->ne[0] == dst->src[2]->ne[0]);
}
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
@@ -334,14 +292,7 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
const float * src1_d = dst->src[1] ? (const float *)dst->src[1]->data : nullptr;
const float * src2_d = dst->src[2] ? (const float *)dst->src[2]->data : nullptr;
const float * src1_d = (const float *)dst->src[1]->data;
auto ne0 = src0->ne;
auto ne1 = dst->src[1] ? dst->src[1]->ne : nullptr;
auto ne2 = dst->src[2] ? dst->src[2]->ne : nullptr;
auto nb1 = dst->src[1] ? dst->src[1]->nb : nullptr;
auto nb2 = dst->src[2] ? dst->src[2]->nb : nullptr;
fused_rms_norm_f32_cuda(src0_d, src1_d, src2_d, dst_d, ne00, nrows, eps, ne0, ne1, ne2, nb1, nb2, stream);
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
}