mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
Allow for f16 source in fused_rms_norm
This commit is contained in:
@@ -176,15 +176,15 @@ static __global__ void rms_norm_f32_nc(
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
|
||||
template <int block_size, typename src_t>
|
||||
static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, float * dst, const int ncols, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row*ncols + col];
|
||||
const float xi = (float)x[row*ncols + col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
@@ -206,13 +206,13 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa
|
||||
const float scale = rsqrtf(mean + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
|
||||
dst[row*ncols + col] = scale * y[col] * (float)x[row*ncols + col];
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
template <int block_size, typename src_t>
|
||||
static __global__ void fused_rms_norm_f32_nc(
|
||||
const float * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const src_t * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps) {
|
||||
const int nrows = gridDim.x;
|
||||
const int nchannels = gridDim.y;
|
||||
@@ -229,7 +229,7 @@ static __global__ void fused_rms_norm_f32_nc(
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
const float xi = (float)x[col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
@@ -257,7 +257,7 @@ static __global__ void fused_rms_norm_f32_nc(
|
||||
const float scale = rsqrtf(mean + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = scale * y[col] * x[col];
|
||||
dst[col] = scale * y[col] * (float)x[col];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,7 +307,8 @@ static void rms_norm_f32_nc_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
|
||||
template <typename src_t>
|
||||
static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst,
|
||||
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
constexpr int kBlockSize = 256;
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
@@ -331,8 +332,9 @@ static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * ds
|
||||
}
|
||||
}
|
||||
|
||||
template <typename src_t>
|
||||
static void fused_rms_norm_f32_nc_cuda(
|
||||
const float * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const src_t * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (ncols < 1024) {
|
||||
@@ -432,7 +434,7 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
|
||||
@@ -445,14 +447,22 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
|
||||
if (ggml_is_contiguous(src0)) {
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
|
||||
} else {
|
||||
fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
|
||||
}
|
||||
} else {
|
||||
auto ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(src0->nb[0] == ts0);
|
||||
auto s01 = src0->nb[1] / ts0;
|
||||
auto s02 = src0->nb[2] / ts0;
|
||||
auto s03 = src0->nb[3] / ts0;
|
||||
fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
|
||||
} else {
|
||||
fused_rms_norm_f32_nc_cuda((const half *)src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7284,7 +7284,19 @@ static struct ggml_tensor * ggml_fused_rms_norm_impl(
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
struct ggml_tensor * result;
|
||||
if (inplace) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
result = ggml_view_tensor(ctx, a);
|
||||
} else {
|
||||
if (a->type == GGML_TYPE_F32) {
|
||||
result = ggml_dup_tensor(ctx, a);
|
||||
} else {
|
||||
result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
|
||||
}
|
||||
}
|
||||
|
||||
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params(result, &eps, sizeof(eps));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user