diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 98d33ebc..c4619e60 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -176,15 +176,15 @@ static __global__ void rms_norm_f32_nc(
     }
 }
 
-template <int block_size>
-static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
+template <int block_size, typename src_t>
+static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
+        const float xi = (float)x[row*ncols + col];
         tmp += xi * xi;
     }
 
@@ -206,13 +206,13 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
+        dst[row*ncols + col] = scale * y[col] * (float)x[row*ncols + col];
     }
 }
 
-template <int block_size>
+template <int block_size, typename src_t>
 static __global__ void fused_rms_norm_f32_nc(
-        const float * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const src_t * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
         const int64_t stride_sample, const float eps) {
     const int nrows = gridDim.x;
     const int nchannels = gridDim.y;
@@ -229,7 +229,7 @@ static __global__ void fused_rms_norm_f32_nc(
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[col];
+        const float xi = (float)x[col];
         tmp += xi * xi;
     }
 
@@ -257,7 +257,7 @@ static __global__ void fused_rms_norm_f32_nc(
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[col] = scale * y[col] * x[col];
+        dst[col] = scale * y[col] * (float)x[col];
     }
 }
 
@@ -307,7 +307,8 @@ static void rms_norm_f32_nc_cuda(
     }
 }
 
-static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
+template <typename src_t>
+static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst,
         const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     constexpr int kBlockSize = 256;
     GGML_ASSERT(ncols % WARP_SIZE == 0);
@@ -331,8 +332,9 @@ static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * ds
     }
 }
 
+template <typename src_t>
 static void fused_rms_norm_f32_nc_cuda(
-        const float * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const src_t * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
         const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
@@ -432,7 +434,7 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
     GGML_ASSERT(src0->ne[0] == src1->ne[0]);
@@ -445,14 +447,22 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
     if (ggml_is_contiguous(src0)) {
         const int64_t nrows = ggml_nrows(src0);
-        fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+        if (src0->type == GGML_TYPE_F32) {
+            fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+        } else {
+            fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+        }
     } else {
         auto ts0 = ggml_type_size(src0->type);
         GGML_ASSERT(src0->nb[0] == ts0);
         auto s01 = src0->nb[1] / ts0;
         auto s02 = src0->nb[2] / ts0;
         auto s03 = src0->nb[3] / ts0;
-        fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
+        if (src0->type == GGML_TYPE_F32) {
+            fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
+        } else {
+            fused_rms_norm_f32_nc_cuda((const half *)src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
+        }
     }
 }
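Reviewer note: the following is a minimal standalone sketch of the pattern this norm.cu change introduces, not code from the patch. One kernel is templated on the source element type, F16 elements are converted to float on load, and both the sum-of-squares accumulation and the final scaling happen in float, exactly what the (float) casts above do. Every name here (fused_rms_norm_sketch, the test harness in main) is invented for illustration, and the reduction is simplified to one warp per row; the real kernels above additionally do block-wide reductions and handle non-contiguous strides.

// --- begin sketch (illustration only, not part of the patch) ---
#include <cstdio>
#include <cuda_fp16.h>

// Templated on the source element type, like the patched kernels:
// convert to float on load, accumulate the sum of squares in float.
template <typename src_t>
static __global__ void fused_rms_norm_sketch(
        const src_t * x, const float * y, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;

    float tmp = 0.0f; // partial sum of squares for this thread
    for (int col = tid; col < ncols; col += 32) {
        const float xi = (float)x[row*ncols + col]; // F16 -> F32 happens here
        tmp += xi * xi;
    }
    for (int offset = 16; offset > 0; offset >>= 1) { // warp-level reduction
        tmp += __shfl_xor_sync(0xffffffff, tmp, offset);
    }
    const float scale = rsqrtf(tmp/ncols + eps);
    for (int col = tid; col < ncols; col += 32) {
        dst[row*ncols + col] = scale * y[col] * (float)x[row*ncols + col];
    }
}

int main() {
    const int nrows = 2, ncols = 64;
    float  hx[nrows*ncols], hy[ncols], hdst[nrows*ncols];
    __half hx16[nrows*ncols];
    for (int i = 0; i < nrows*ncols; ++i) { hx[i] = 0.01f*i; hx16[i] = __float2half(hx[i]); }
    for (int i = 0; i < ncols; ++i) { hy[i] = 1.0f; }

    float *dx, *dy, *ddst; __half *dx16;
    cudaMalloc(&dx,   sizeof(hx));   cudaMemcpy(dx,   hx,   sizeof(hx),   cudaMemcpyHostToDevice);
    cudaMalloc(&dx16, sizeof(hx16)); cudaMemcpy(dx16, hx16, sizeof(hx16), cudaMemcpyHostToDevice);
    cudaMalloc(&dy,   sizeof(hy));   cudaMemcpy(dy,   hy,   sizeof(hy),   cudaMemcpyHostToDevice);
    cudaMalloc(&ddst, sizeof(hdst));

    // Two instantiations of the one kernel -- this is what the F32/F16
    // branch added to ggml_cuda_op_fused_rms_norm selects between.
    fused_rms_norm_sketch<<<nrows, 32>>>(dx,   dy, ddst, ncols, 1e-6f);
    fused_rms_norm_sketch<<<nrows, 32>>>(dx16, dy, ddst, ncols, 1e-6f);

    cudaMemcpy(hdst, ddst, sizeof(hdst), cudaMemcpyDeviceToHost);
    printf("dst[0] = %f (F16 path)\n", hdst[0]);
    cudaFree(dx); cudaFree(dx16); cudaFree(dy); cudaFree(ddst);
    return 0;
}
// --- end sketch ---

Note that template argument deduction from the pointer type picks the instantiation, which is why the host launchers above only needed a template parameter and a (const half *) cast at the call site.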
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 41501453..3bddc026 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -7284,7 +7284,19 @@ static struct ggml_tensor * ggml_fused_rms_norm_impl(
         is_node = true;
     }
 
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result;
+    if (inplace) {
+        GGML_ASSERT(a->type == GGML_TYPE_F32);
+        result = ggml_view_tensor(ctx, a);
+    } else {
+        if (a->type == GGML_TYPE_F32) {
+            result = ggml_dup_tensor(ctx, a);
+        } else {
+            result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
+        }
+    }
+
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, &eps, sizeof(eps));
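The ggml.c side can be summarized the same way. The sketch below uses an invented toy_tensor in place of the real ggml_tensor and encodes the new allocation rule: in-place fused RMS norm stays F32-only, and a non-F32 input gets a freshly allocated F32 result rather than a dup of a, matching the kernels above, which always write float output.

// --- begin sketch (illustration only; toy types, not the real ggml structs) ---
#include <cassert>
#include <cstdio>

enum toy_type { TOY_F32, TOY_F16 };

struct toy_tensor {
    toy_type type;
    long     ne[4];
};

// Mirrors the branch added to ggml_fused_rms_norm_impl.
static toy_tensor make_result(const toy_tensor & a, bool inplace) {
    if (inplace) {
        assert(a.type == TOY_F32);  // corresponds to the new GGML_ASSERT
        return a;                   // stands in for ggml_view_tensor
    }
    if (a.type == TOY_F32) {
        return a;                   // stands in for ggml_dup_tensor
    }
    // stands in for ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ...)
    return toy_tensor{TOY_F32, {a.ne[0], a.ne[1], a.ne[2], a.ne[3]}};
}

int main() {
    toy_tensor a16{TOY_F16, {4096, 32, 1, 1}};
    toy_tensor r = make_result(a16, /*inplace=*/false);
    printf("F16 input -> %s result\n", r.type == TOY_F32 ? "F32" : "F16");
    return 0;
}
// --- end sketch ---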