Allow for f16 source in fused_rms_norm

This commit is contained in:
Kawrakow
2025-11-27 14:45:35 +00:00
parent fbbac10872
commit ed67bcbb2a
2 changed files with 36 additions and 14 deletions

View File

@@ -176,15 +176,15 @@ static __global__ void rms_norm_f32_nc(
}
}
template <int block_size>
static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
template <int block_size, typename src_t>
static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
const float xi = (float)x[row*ncols + col];
tmp += xi * xi;
}
@@ -206,13 +206,13 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
dst[row*ncols + col] = scale * y[col] * (float)x[row*ncols + col];
}
}
template <int block_size>
template <int block_size, typename src_t>
static __global__ void fused_rms_norm_f32_nc(
const float * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const src_t * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
@@ -229,7 +229,7 @@ static __global__ void fused_rms_norm_f32_nc(
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
const float xi = (float)x[col];
tmp += xi * xi;
}
@@ -257,7 +257,7 @@ static __global__ void fused_rms_norm_f32_nc(
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * y[col] * x[col];
dst[col] = scale * y[col] * (float)x[col];
}
}
@@ -307,7 +307,8 @@ static void rms_norm_f32_nc_cuda(
}
}
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
template <typename src_t>
static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst,
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
constexpr int kBlockSize = 256;
GGML_ASSERT(ncols % WARP_SIZE == 0);
@@ -331,8 +332,9 @@ static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * ds
}
}
template <typename src_t>
static void fused_rms_norm_f32_nc_cuda(
const float * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const src_t * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
@@ -432,7 +434,7 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
@@ -445,14 +447,22 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
if (ggml_is_contiguous(src0)) {
const int64_t nrows = ggml_nrows(src0);
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
if (src0->type == GGML_TYPE_F32) {
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
} else {
fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
}
} else {
auto ts0 = ggml_type_size(src0->type);
GGML_ASSERT(src0->nb[0] == ts0);
auto s01 = src0->nb[1] / ts0;
auto s02 = src0->nb[2] / ts0;
auto s03 = src0->nb[3] / ts0;
fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
if (src0->type == GGML_TYPE_F32) {
fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
} else {
fused_rms_norm_f32_nc_cuda((const half *)src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
}
}
}

View File

@@ -7284,7 +7284,19 @@ static struct ggml_tensor * ggml_fused_rms_norm_impl(
is_node = true;
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result;
if (inplace) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
result = ggml_view_tensor(ctx, a);
} else {
if (a->type == GGML_TYPE_F32) {
result = ggml_dup_tensor(ctx, a);
} else {
result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
}
}
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
ggml_set_op_params(result, &eps, sizeof(eps));