qwen3next: add architecture support and recurrent-state fixes

yurko
2026-02-06 12:13:09 +00:00
parent a527b5af25
commit a7df116441
28 changed files with 2729 additions and 14 deletions

View File

@@ -601,6 +601,7 @@ extern "C" {
GGML_OP_LOG,
GGML_OP_SUM,
GGML_OP_SUM_ROWS,
GGML_OP_CUMSUM,
GGML_OP_MEAN,
GGML_OP_ARGMAX,
GGML_OP_REPEAT,
@@ -611,6 +612,7 @@ extern "C" {
GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
GGML_OP_L2_NORM,
GGML_OP_FUSED_RMS_NORM,
GGML_OP_FUSED_MUL_UNARY,
GGML_OP_MULTI_ADD,
@@ -653,6 +655,8 @@ extern "C" {
GGML_OP_PAD,
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_TRI,
GGML_OP_FILL,
GGML_OP_ARGSORT,
GGML_OP_ARGSORT_THRESH,
GGML_OP_GROUPED_TOPK,
@@ -671,6 +675,7 @@ extern "C" {
GGML_OP_WIN_UNPART,
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_SOLVE_TRI,
GGML_OP_UNARY,
GGML_OP_MAP_UNARY,
@@ -710,6 +715,8 @@ extern "C" {
GGML_UNARY_OP_SILU,
GGML_UNARY_OP_HARDSWISH,
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_SOFTPLUS,
GGML_UNARY_OP_SWIGLU,
GGML_UNARY_OP_SWIGLU_OAI,
GGML_UNARY_OP_GELU,
@@ -739,6 +746,13 @@ extern "C" {
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
};
enum ggml_tri_type {
GGML_TRI_TYPE_LOWER,
GGML_TRI_TYPE_UPPER,
GGML_TRI_TYPE_LOWER_DIAG,
GGML_TRI_TYPE_UPPER_DIAG,
};
// ggml object
struct ggml_object {
size_t offs;
@@ -1189,6 +1203,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_softplus(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_softplus_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// return scalar
GGML_API struct ggml_tensor * ggml_sum(
struct ggml_context * ctx,
@@ -1199,6 +1221,10 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_cumsum(
struct ggml_context * ctx,
struct ggml_tensor * a);
// mean along rows
GGML_API struct ggml_tensor * ggml_mean(
struct ggml_context * ctx,
@@ -1217,6 +1243,15 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
// repeat a to specified shape
GGML_API struct ggml_tensor * ggml_repeat_4d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3);
// sums repetitions in a into shape of b
GGML_API struct ggml_tensor * ggml_repeat_back(
struct ggml_context * ctx,
@@ -1455,6 +1490,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_exp(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_exp_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
@@ -1514,6 +1557,17 @@ extern "C" {
int n_groups,
float eps);
// l2 normalize along rows
GGML_API struct ggml_tensor * ggml_l2_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps);
GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps);
// a - x
// b - dy
GGML_API struct ggml_tensor * ggml_rms_norm_back(
@@ -2283,6 +2337,23 @@ extern "C" {
int dim,
int max_period);
// convert matrix to triangular form by zeroing values outside selected half
GGML_API struct ggml_tensor * ggml_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_tri_type type);
// fill tensor with constant c
GGML_API struct ggml_tensor * ggml_fill(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c);
GGML_API struct ggml_tensor * ggml_fill_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c);
// sort rows
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
@@ -2426,6 +2497,15 @@ extern "C" {
struct ggml_tensor * pw,
struct ggml_tensor * ph);
// Solve Ax = B where A is triangular
GGML_API struct ggml_tensor * ggml_solve_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
bool left,
bool lower,
bool uni);
// custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
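
A minimal usage sketch of the new operators declared above, assuming a ggml_context * ctx obtained from ggml_init(); the tensor names and shapes here are illustrative only:

// hypothetical shapes: A is n x n (n = 64), B is k x n (ne0 = k = 16 columns, ne1 = n rows)
struct ggml_tensor * A = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
struct ggml_tensor * B = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 64);
// keep the lower triangle of A including the diagonal, zero the rest
struct ggml_tensor * L = ggml_tri(ctx, A, GGML_TRI_TYPE_LOWER_DIAG);
// solve L * X = B; this commit only supports the left/lower/non-unit variant
struct ggml_tensor * X = ggml_solve_tri(ctx, L, B, /*left=*/true, /*lower=*/true, /*uni=*/false);
// row-wise x / max(||x||_2, eps), then softplus and a cumulative sum along ne0
struct ggml_tensor * Y = ggml_cumsum(ctx, ggml_softplus(ctx, ggml_l2_norm(ctx, X, 1e-6f)));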

View File

@@ -18,9 +18,11 @@
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/convert.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/dmmv.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/fill.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
@@ -46,10 +48,13 @@
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/multiadd.cuh"
#include "ggml-cuda/hadamard.cuh"
#include "ggml-cuda/reduce.cuh"
#include "ggml-cuda/tri.cuh"
#include <algorithm>
#include <array>
@@ -3295,6 +3300,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_HARDSWISH:
ggml_cuda_op_hardswish(ctx, dst);
break;
case GGML_UNARY_OP_EXP:
ggml_cuda_op_exp(ctx, dst);
break;
case GGML_UNARY_OP_SOFTPLUS:
ggml_cuda_op_softplus(ctx, dst);
break;
default:
return -1;
}
@@ -3329,6 +3340,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GROUP_NORM:
ggml_cuda_op_group_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_cuda_op_l2_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
if (fusion && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
@@ -3544,6 +3558,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_sum_rows(ctx, dst);
}
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_ARGSORT:
if (fusion && i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
@@ -3563,6 +3580,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GROUPED_TOPK:
ggml_cuda_op_grouped_topk(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;
case GGML_OP_FILL:
ggml_cuda_op_fill(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cuda_op_solve_tri(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
@@ -4139,6 +4168,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -4332,6 +4363,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_L2_NORM:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
break;
@@ -4379,6 +4412,38 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_CUMSUM:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_TRI:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[0]->type == op->type;
case GGML_OP_FILL:
return ggml_is_contiguous(op) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_SOLVE_TRI:
return ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
op->src[0]->ne[0] == op->src[0]->ne[1] &&
op->src[0]->ne[1] == op->src[1]->ne[1] &&
op->src[0]->ne[2] == op->src[1]->ne[2] &&
op->src[0]->ne[3] == op->src[1]->ne[3];
case GGML_OP_SSM_CONV:
return op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->src[2]->type == GGML_TYPE_F32 &&
op->src[3]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_F32 &&
op->src[0]->nb[0] == sizeof(float) &&
op->src[1]->nb[0] == sizeof(float) &&
op->src[2]->nb[0] == sizeof(float) &&
op->src[3]->nb[0] == sizeof(int32_t) &&
op->src[2]->ne[0] == op->src[0]->ne[0] + 1 &&
op->src[2]->ne[1] == op->src[0]->ne[1] &&
op->src[1]->ne[0] == op->src[0]->ne[1] &&
op->src[3]->ne[0] == op->src[0]->ne[2];
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;

View File

@@ -0,0 +1,76 @@
#include "cumsum.cuh"
#define CUDA_CUMSUM_BLOCK_SIZE 256
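// one thread block per row: the grid spans (ne01, ne02, ne03) and the row of length ne00 is
// processed in tiles of blockDim.x elements; each tile runs an inclusive Hillis-Steele scan
// in shared memory and a running carry (the total of all previous tiles) is added on write-back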
static __global__ void cumsum_f32_kernel(
const float * src, float * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t d0, const int64_t d1, const int64_t d2, const int64_t d3) {
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
const int64_t i3 = blockIdx.z;
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
return;
}
const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3;
extern __shared__ float s_scan[];
float carry = 0.0f;
for (int64_t start = 0; start < ne00; start += blockDim.x) {
const int tile_n = (int) ((ne00 - start) < (int64_t) blockDim.x ? (ne00 - start) : (int64_t) blockDim.x);
float value = 0.0f;
if (threadIdx.x < tile_n) {
value = src_row[(start + threadIdx.x) * s00];
}
s_scan[threadIdx.x] = value;
__syncthreads();
for (int offset = 1; offset < blockDim.x; offset <<= 1) {
float add = 0.0f;
if (threadIdx.x >= offset) {
add = s_scan[threadIdx.x - offset];
}
__syncthreads();
if (threadIdx.x >= offset) {
s_scan[threadIdx.x] += add;
}
__syncthreads();
}
if (threadIdx.x < tile_n) {
dst_row[(start + threadIdx.x) * d0] = s_scan[threadIdx.x] + carry;
}
__syncthreads();
// every thread accumulates the tile total so the running carry stays consistent across the block
carry += s_scan[tile_n - 1];
__syncthreads();
}
}
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
int block_size = WARP_SIZE;
while (block_size < src0->ne[0] && block_size < CUDA_CUMSUM_BLOCK_SIZE) {
block_size <<= 1;
}
dim3 grid_dims(src0->ne[1], src0->ne[2], src0->ne[3]);
cumsum_f32_kernel<<<grid_dims, block_size, block_size * sizeof(float), ctx.stream()>>>(
(const float *) src0->data,
(float *) dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0] / sizeof(float), src0->nb[1] / sizeof(float), src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
dst->nb[0] / sizeof(float), dst->nb[1] / sizeof(float), dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float));
}

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,34 @@
#include "fill.cuh"
#include "convert.cuh"
#define CUDA_FILL_BLOCK_SIZE 256
template <typename T>
static __global__ void fill_kernel(T * dst, const int64_t k, const T value) {
const int64_t i = (int64_t) blockDim.x * blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = value;
}
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst));
float value = 0.0f;
memcpy(&value, dst->op_params, sizeof(float));
const int64_t k = ggml_nelements(dst);
const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;
switch (dst->type) {
case GGML_TYPE_F32:
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, ctx.stream()>>>((float *) dst->data, k, value);
break;
case GGML_TYPE_F16:
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, ctx.stream()>>>((half *) dst->data, k, ggml_cuda_cast<half>(value));
break;
default:
GGML_ABORT("unsupported type");
}
}

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -185,6 +185,38 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
}
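// one row per block: each thread accumulates partial sums of x*x, reduced across the warp
// (and across warps via shared memory when block_size > WARP_SIZE); the scale
// rsqrtf(max(sum, eps*eps)) equals 1 / max(||x||_2, eps), matching the CPU l2_norm path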
template <int block_size>
static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x * blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float tmp = 0.0f;
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row * ncols + col];
tmp += xi * xi;
}
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = lane_id < block_size / WARP_SIZE ? s_sum[lane_id] : 0.0f;
tmp = warp_reduce_sum(tmp);
}
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[row * ncols + col] = scale * x[row * ncols + col];
}
}
template <int block_size>
static __global__ void rms_norm_f32_nc(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
@@ -230,6 +262,49 @@ static __global__ void rms_norm_f32_nc(
}
}
template <int block_size>
static __global__ void l2_norm_f32_nc(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
x += sample * stride_sample + channel * stride_channel + row * stride_row;
dst += ((sample * nchannels + channel) * nrows + row) * ncols;
float tmp = 0.0f;
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
}
}
template <int block_size, typename src_t>
static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -387,6 +462,31 @@ static void rms_norm_f32_nc_cuda(
}
}
static void l2_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
constexpr int kBlockSize = 256;
if (ncols < 1024) {
const dim3 block_dims(kBlockSize, 1, 1);
l2_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
}
static void l2_norm_f32_nc_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
l2_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
template <typename src_t>
static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst,
const int ncols, const int nrows, const float eps, bool is_norm, cudaStream_t stream) {
@@ -527,6 +627,32 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
}
}
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
float eps = 0.0f;
memcpy(&eps, dst->op_params, sizeof(float));
const int64_t ne00 = src0->ne[0];
if (ggml_is_contiguous(src0)) {
const int64_t nrows = ggml_nrows(src0);
l2_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
} else {
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(src0->nb[0] == ts0);
const int64_t s01 = src0->nb[1] / ts0;
const int64_t s02 = src0->nb[2] / ts0;
const int64_t s03 = src0->nb[3] / ts0;
l2_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
}
}
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm) {
if (!dst->src[1]) {
ggml_cuda_op_rms_norm(ctx, dst);

View File

@@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm = false);
void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst);

View File

@@ -0,0 +1,290 @@
#include "solve_tri.cuh"
#define MAX_N_FAST 64
#define MAX_K_FAST 32
static __global__ void get_batch_pointers(
const float * A,
float * X,
const float ** A_ptrs,
float ** X_ptrs,
int64_t ne02,
int64_t total_batches,
size_t s02,
size_t s03,
size_t s2,
size_t s3) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total_batches) {
return;
}
const int64_t i3 = idx / ne02;
const int64_t i2 = idx % ne02;
A_ptrs[idx] = A + i3 * s03 + i2 * s02;
X_ptrs[idx] = X + i3 * s3 + i2 * s2;
}
static void solve_tri_f32_cublas(
ggml_backend_cuda_context & ctx,
const float * A,
const float * B,
float * X,
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t s02,
size_t s03,
size_t s2,
size_t s3,
cudaStream_t stream) {
const float alpha = 1.0f;
const int64_t total_batches = ne02 * ne03;
if (total_batches == 0) {
return;
}
if (X != B) {
const int64_t total_elements = int64_t(n) * int64_t(k) * total_batches;
CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements * sizeof(float), cudaMemcpyDeviceToDevice, stream));
}
const int id = ggml_cuda_get_device();
ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);
const float ** A_ptrs_dev = A_ptrs_alloc.get();
float ** X_ptrs_dev = X_ptrs_alloc.get();
constexpr int kBlockSize = 256;
const int blocks = (total_batches + kBlockSize - 1) / kBlockSize;
get_batch_pointers<<<blocks, kBlockSize, 0, stream>>>(
A, X, A_ptrs_dev, X_ptrs_dev,
ne02, total_batches, s02, s03, s2, s3);
cublasHandle_t handle = ctx.cublas_handle(id);
CUBLAS_CHECK(cublasSetStream(handle, stream));
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif
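// X was copied from B above because cublas<T>trsm solves in place, overwriting its B argument.
// ggml tensors are row-major while cuBLAS is column-major: the row-major left solve A*X = B with
// A lower-triangular is the column-major right solve X^T * A^T = B^T with A^T upper-triangular,
// hence SIDE_RIGHT / FILL_MODE_UPPER with m = k, n = n, lda = n, ldb = k below.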
CUBLAS_CHECK(cublasStrsmBatched(
handle,
CUBLAS_SIDE_RIGHT,
CUBLAS_FILL_MODE_UPPER,
CUBLAS_OP_N,
CUBLAS_DIAG_NON_UNIT,
k,
n,
&alpha,
A_ptrs_dev,
n,
X_ptrs_dev,
k,
total_batches));
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
}
// Small triangular systems are faster with a custom kernel than with batched cublasStrsm.
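// Layout: one block per (i2, i3) batch, blockDim = (WARP_SIZE, k), i.e. one warp per column of B.
// A is staged in shared memory; the solution column lives in two per-lane registers,
// x_low for rows [0, WARP_SIZE) and x_high for rows [WARP_SIZE, n). Forward substitution sums the
// partial products A[row][j] * x[j] across the warp with warp_reduce_sum and the owning lane
// divides by the diagonal element.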
template <int n_template, int k_template>
static __global__ void solve_tri_f32_fast(
const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const int64_t ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3,
const int n_arg,
const int k_arg) {
const int n = n_template == 0 ? n_arg : n_template;
const int k = k_template == 0 ? k_arg : k_template;
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int col_idx = threadIdx.y;
if (col_idx >= k) {
return;
}
const int64_t i03 = batch_idx / ne02;
const int64_t i02 = batch_idx - i03 * ne02;
const float * A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
const int i0 = i + offset;
if (i0 < n * n) {
sA[i0] = A_batch[i0];
}
}
__syncthreads();
float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
const int half = WARP_SIZE;
const int nrows_low = (n < half) ? n : half;
#pragma unroll
for (int row = 0; row < nrows_low; ++row) {
float sum = 0.0f;
if (lane < row) {
sum += sA[row * n + lane] * x_low;
}
sum = warp_reduce_sum(sum);
if (lane == row) {
x_low = (x_low - sum) / sA[row * n + row];
}
}
#pragma unroll
for (int row = half; row < n; ++row) {
float sum = sA[row * n + lane] * x_low;
const int j = half + lane;
if (j < row) {
sum += sA[row * n + j] * x_high;
}
sum = warp_reduce_sum(sum);
if (lane == row - half) {
x_high = (x_high - sum) / sA[row * n + row];
}
}
#pragma unroll
for (int rr = 0; rr < 2; ++rr) {
const int row = rr * WARP_SIZE + lane;
if (row < n) {
const float val = (row < half) ? x_low : x_high;
X_batch[row * k + col_idx] = val;
}
}
}
static void solve_tri_f32_cuda(
const float * A,
const float * B,
float * X,
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t nb02,
size_t nb03,
size_t nb12,
size_t nb13,
size_t nb2,
size_t nb3,
cudaStream_t stream) {
dim3 threads(WARP_SIZE, k);
dim3 grid(ne02 * ne03);
if (n == 64) {
switch (k) {
case 32:
solve_tri_f32_fast<64, 32><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 16:
solve_tri_f32_fast<64, 16><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 14:
solve_tri_f32_fast<64, 14><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 12:
solve_tri_f32_fast<64, 12><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 10:
solve_tri_f32_fast<64, 10><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 8:
solve_tri_f32_fast<64, 8><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 6:
solve_tri_f32_fast<64, 6><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 4:
solve_tri_f32_fast<64, 4><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 2:
solve_tri_f32_fast<64, 2><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 1:
solve_tri_f32_fast<64, 1><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
default:
solve_tri_f32_fast<0, 0><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
} else {
solve_tri_f32_fast<0, 0><<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
}
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(src0->ne[0] == src0->ne[1]);
GGML_ASSERT(src0->ne[1] == src1->ne[1]);
GGML_ASSERT(src0->ne[2] == src1->ne[2]);
GGML_ASSERT(src0->ne[3] == src1->ne[3]);
const int n = src0->ne[0];
const int k = src1->ne[0];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
solve_tri_f32_cuda(
(const float *) src0->data,
(const float *) src1->data,
(float *) dst->data,
n, k,
ne02, ne03,
src0->nb[2] / sizeof(float),
src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float),
src1->nb[3] / sizeof(float),
dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float),
ctx.stream());
} else {
solve_tri_f32_cublas(
ctx,
(const float *) src0->data,
(const float *) src1->data,
(float *) dst->data,
n, k,
ne02, ne03,
src0->nb[2] / sizeof(float),
src0->nb[3] / sizeof(float),
dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float),
ctx.stream());
}
}

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,268 @@
#include "ssm-conv.cuh"
#define CUDA_SSM_CONV_BLOCK_SIZE 256
static __global__ void ssm_conv_init_states_f32_nc4(
const float * src0,
float * state,
int nr,
int n_kv) {
const int row = blockIdx.x * blockDim.x + threadIdx.x;
const int seq = blockIdx.y;
if (row >= nr || seq >= n_kv) {
return;
}
const float * src_row = src0 + (size_t) seq * nr * 3 + (size_t) row * 3;
float * state_row = state + (size_t) seq * nr * 4 + (size_t) row * 4;
state_row[1] = src_row[0];
state_row[2] = src_row[1];
state_row[3] = src_row[2];
}
static __global__ void ssm_conv_init_states_f32(
const float * src0,
float * state,
int nc,
int nr,
int n_kv) {
const int row = blockIdx.x * blockDim.x + threadIdx.x;
const int seq = blockIdx.y;
if (row >= nr || seq >= n_kv) {
return;
}
const float * src_row = src0 + (size_t) seq * nr * (nc - 1) + (size_t) row * (nc - 1);
float * state_row = state + (size_t) seq * nr * nc + (size_t) row * nc;
for (int i0 = 0; i0 < nc - 1; ++i0) {
state_row[1 + i0] = src_row[i0];
}
}
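// One thread per inner row, iterating over the tokens sequentially. For token t the previous
// conv state is read from src0 (width nc-1) at t == 0 and from the already-written dst row
// (shifted by one) afterwards; the row is shifted left, the new input x is appended at
// position nc-1, the updated row is mirrored to any extra sequence ids listed in state_seq,
// and the output is the dot product of the state row with the conv weight row.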
static __global__ void ssm_conv_f32_kernel(
const float * src0,
const float * src1,
const float * src2,
const int32_t * src3,
float * dst_x,
float * dst_state,
int nc,
int nr,
int n_t,
int n_kv,
int src1_nb1,
int src3_nb1) {
const int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= nr) {
return;
}
const float * c_row = src2 + (size_t) row * nc;
for (int t = 0; t < n_t; ++t) {
const int32_t * sq = src3 + (size_t) t * src3_nb1;
const int seq0 = sq[0];
if (seq0 < 0 || seq0 >= n_kv) {
continue;
}
float * state_row = dst_state + (size_t) seq0 * nr * nc + (size_t) row * nc;
const float * src_state_row;
if (t == 0) {
src_state_row = src0 + (size_t) seq0 * nr * (nc - 1) + (size_t) row * (nc - 1);
} else {
src_state_row = state_row + 1;
}
for (int i0 = 0; i0 < nc - 1; ++i0) {
state_row[i0] = src_state_row[i0];
}
state_row[nc - 1] = src1[row + (size_t) t * src1_nb1];
for (int i3 = 1; i3 < n_kv; ++i3) {
const int seq = sq[i3];
if (seq < 0 || seq >= n_kv) {
break;
}
float * state_row_copy = dst_state + (size_t) seq * nr * nc + (size_t) row * nc;
for (int i0 = 0; i0 < nc; ++i0) {
state_row_copy[i0] = state_row[i0];
}
}
float sumf = 0.0f;
for (int i0 = 0; i0 < nc; ++i0) {
sumf += state_row[i0] * c_row[i0];
}
dst_x[row + (size_t) t * nr] = sumf;
}
}
template <bool has_multi_seq>
static __global__ void ssm_conv_f32_kernel_nc4(
const float * src0,
const float * src1,
const float * src2,
const int32_t * src3,
float * dst_x,
float * dst_state,
int nr,
int n_t,
int n_kv,
int src1_nb1,
int src3_nb1) {
const int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= nr) {
return;
}
const float * c_row = src2 + (size_t) row * 4;
const float c0 = c_row[0];
const float c1 = c_row[1];
const float c2 = c_row[2];
const float c3 = c_row[3];
for (int t = 0; t < n_t; ++t) {
const int32_t * sq = src3 + (size_t) t * src3_nb1;
const int seq0 = sq[0];
if (seq0 < 0 || seq0 >= n_kv) {
continue;
}
float * state_row = dst_state + (size_t) seq0 * nr * 4 + (size_t) row * 4;
const float * src_state_row;
if (t == 0) {
src_state_row = src0 + (size_t) seq0 * nr * 3 + (size_t) row * 3;
} else {
src_state_row = state_row + 1;
}
const float s0 = src_state_row[0];
const float s1 = src_state_row[1];
const float s2 = src_state_row[2];
const float x = src1[row + (size_t) t * src1_nb1];
state_row[0] = s0;
state_row[1] = s1;
state_row[2] = s2;
state_row[3] = x;
if constexpr (has_multi_seq) {
for (int i3 = 1; i3 < n_kv; ++i3) {
const int seq = sq[i3];
if (seq < 0 || seq >= n_kv) {
break;
}
float * state_row_copy = dst_state + (size_t) seq * nr * 4 + (size_t) row * 4;
state_row_copy[0] = s0;
state_row_copy[1] = s1;
state_row_copy[2] = s2;
state_row_copy[3] = x;
}
}
dst_x[row + (size_t) t * nr] = s0 * c0 + s1 * c1 + s2 * c2 + x * c3;
}
}
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // conv_state: [d_conv - 1, d_inner, n_kv]
const ggml_tensor * src1 = dst->src[1]; // x: [d_inner, n_tokens]
const ggml_tensor * src2 = dst->src[2]; // conv1d.weight: [d_conv, d_inner]
const ggml_tensor * src3 = dst->src[3]; // state_seq: [n_kv, n_tokens]
const int nc = src2->ne[0];
const int nr = src0->ne[1];
const int n_t = src1->ne[1];
const int n_kv = src0->ne[2];
GGML_ASSERT((int64_t) nr * n_t + (int64_t) nc * nr * n_kv == ggml_nelements(dst));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src3->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
GGML_ASSERT(src2->nb[1] == src2->ne[0] * sizeof(float));
GGML_ASSERT(src2->nb[2] == src2->ne[1] * src2->ne[0] * sizeof(float));
GGML_ASSERT(src2->ne[0] == src0->ne[0] + 1);
GGML_ASSERT(src2->ne[1] == src0->ne[1]);
GGML_ASSERT(src1->ne[0] == src0->ne[1]);
GGML_ASSERT(src3->ne[0] == src0->ne[2]);
GGML_ASSERT(src3->ne[1] == src1->ne[1]);
float * dst_data = (float *) dst->data;
float * dst_x = dst_data;
float * dst_state = dst_data + (size_t) nr * n_t;
const dim3 block_dims(CUDA_SSM_CONV_BLOCK_SIZE, 1, 1);
const dim3 row_grid((nr + CUDA_SSM_CONV_BLOCK_SIZE - 1) / CUDA_SSM_CONV_BLOCK_SIZE, 1, 1);
if (n_kv > 1) {
const dim3 init_grid(row_grid.x, n_kv, 1);
if (nc == 4) {
ssm_conv_init_states_f32_nc4<<<init_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
dst_state,
nr, n_kv);
} else {
ssm_conv_init_states_f32<<<init_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
dst_state,
nc, nr, n_kv);
}
}
if (nc == 4) {
if (n_kv > 1) {
ssm_conv_f32_kernel_nc4<true><<<row_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
(const int32_t *) src3->data,
dst_x,
dst_state,
nr, n_t, n_kv,
src1->nb[1] / sizeof(float),
src3->nb[1] / sizeof(int32_t));
} else {
ssm_conv_f32_kernel_nc4<false><<<row_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
(const int32_t *) src3->data,
dst_x,
dst_state,
nr, n_t, n_kv,
src1->nb[1] / sizeof(float),
src3->nb[1] / sizeof(int32_t));
}
} else {
ssm_conv_f32_kernel<<<row_grid, block_dims, 0, ctx.stream()>>>(
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
(const int32_t *) src3->data,
dst_x,
dst_state,
nc, nr, n_t, n_kv,
src1->nb[1] / sizeof(float),
src3->nb[1] / sizeof(int32_t));
}
}

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/tri.cu (new file, 120 lines)
View File

@@ -0,0 +1,120 @@
#include "tri.cuh"
#include "convert.cuh"
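// One block per row i1 (grid = ne01 x ne02 x ne03); threads stride over the columns of that row.
// split_point = i1 + add_to_split separates the kept part from the zeroed part:
// prefix_keep (LOWER / LOWER_DIAG) copies columns [0, split_point) and zeros the rest,
// the upper variants do the opposite; add_to_split is 1 for LOWER_DIAG and UPPER so the
// diagonal element ends up on the correct side.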
template<typename T, bool prefix_keep, int add_to_split>
static __global__ void tri_kernel(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
const int64_t split_point = i1 + add_to_split;
(void) nb00;
(void) nb0;
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
const T * src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03;
T * dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3;
if constexpr (prefix_keep) {
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
dst_row[i0] = src_row[i0];
}
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
}
} else {
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
}
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
dst_row[i0] = src_row[i0];
}
}
}
template<typename T>
static void tri_cuda(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const ggml_tri_type ttype,
cudaStream_t stream) {
dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
dim3 grid_dims(ne01, ne02, ne03);
const size_t type_size = sizeof(T);
const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
if (prefix_keep) {
if (add_to_split == 0) {
tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else {
tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
} else {
if (add_to_split == 0) {
tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else {
tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
}
}
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tri_type ttype = static_cast<ggml_tri_type>(((const int32_t *) dst->op_params)[0]);
GGML_ASSERT(src0->type == dst->type);
switch (src0->type) {
case GGML_TYPE_F32:
tri_cuda(
(const float *) src0->data, (float *) dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, ctx.stream()
);
break;
case GGML_TYPE_F16:
tri_cuda(
(const half *) src0->data, (half *) dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, ctx.stream()
);
break;
default:
GGML_ABORT("fatal error");
}
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_TRI_BLOCK_SIZE 256
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -641,6 +641,10 @@ static __device__ __forceinline__ float op_exp(float x) {
return expf(x);
}
static __device__ __forceinline__ float op_softplus(float x) {
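// softplus(x) = log(1 + e^x); for x > 20 the result equals x to within float precision,
// and the early-out also avoids overflow in expf for large x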
return (x > 20.0f) ? x : logf(1.0f + expf(x));
}
static __device__ __forceinline__ float op_sin(float x) {
return sinf(x);
}
@@ -737,6 +741,10 @@ void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_exp>(ctx, dst);
}
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_softplus>(ctx, dst);
}
// === gated ops
template <float (*op)(float), typename T>
@@ -848,4 +856,3 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_elu>(ctx, dst);
}

View File

@@ -53,6 +53,8 @@ void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -3288,6 +3288,9 @@ inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) {
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
inline static float ggml_compute_softplus_f32(const float x) { return x > 20.0f ? x : logf(1.0f + expf(x)); }
inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
inline static void ggml_vec_softplus_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = ggml_compute_softplus_f32(x[i]); }
// TODO: optimize performance
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
@@ -4196,6 +4199,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"LOG",
"SUM",
"SUM_ROWS",
"CUMSUM",
"MEAN",
"ARGMAX",
"REPEAT",
@@ -4206,6 +4210,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"RMS_NORM",
"RMS_NORM_BACK",
"GROUP_NORM",
"L2_NORM",
"FUSED_RMS_NORM",
"FUSED_MUL_UNARY",
"MULTI_ADD",
@@ -4248,6 +4253,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"PAD",
"ARANGE",
"TIMESTEP_EMBEDDING",
"TRI",
"FILL",
"ARGSORT",
"ARGSORT_THRESH",
"GROUPED_TOPK",
@@ -4266,6 +4273,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"WIN_UNPART",
"GET_REL_POS",
"ADD_REL_POS",
"SOLVE_TRI",
"UNARY",
@@ -4290,7 +4298,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"FUSED_NORM",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -4308,6 +4316,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"log(x)",
"Σx",
"Σx_k",
"cumsum(x)",
"Σx/n",
"argmax(x)",
"repeat(x)",
@@ -4318,6 +4327,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"rms_norm(x)",
"rms_norm_back(x)",
"group_norm(x)",
"l2_norm(x)",
"fused_rms_norm(x)",
"fused_mul_unary(x)",
"x1+x2+x3+...",
@@ -4360,6 +4370,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"pad(x)",
"arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)",
"tri(x)",
"fill(x)",
"argsort(x)",
"argsort_thresh(x)",
"grouped_topk(x)",
@@ -4378,6 +4390,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"win_unpart(x)",
"get_rel_pos(x)",
"add_rel_pos(x)",
"solve_tri(x)",
"unary(x)",
@@ -4402,7 +4415,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"norm(x,y)",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4416,17 +4429,19 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"ELU",
"RELU",
"SIGMOID",
"GELU",
"GELU_ERF",
"GELU_QUICK",
"SILU",
"HARDSWISH",
"HARDSIGMOID",
"EXP",
"SOFTPLUS",
"SWIGLU",
"SWIGLU_OAI",
"GELU_ERF",
"GELU",
};
static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
static_assert(GGML_UNARY_OP_COUNT == 18, "GGML_UNARY_OP_COUNT != 18");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -6701,6 +6716,28 @@ struct ggml_tensor * ggml_sum_rows(
return result;
}
// ggml_cumsum
struct ggml_tensor * ggml_cumsum(
struct ggml_context * ctx,
struct ggml_tensor * a) {
bool is_node = false;
GGML_ASSERT(a->type == GGML_TYPE_F32);
if (a->grad) {
is_node = true;
}
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
result->op = GGML_OP_CUMSUM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_mean
struct ggml_tensor * ggml_mean(
@@ -6769,6 +6806,33 @@ struct ggml_tensor * ggml_repeat(
return result;
}
struct ggml_tensor * ggml_repeat_4d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
const bool can_repeat = ggml_is_empty(a) || (
(ne0 % a->ne[0] == 0) &&
(ne1 % a->ne[1] == 0) &&
(ne2 % a->ne[2] == 0) &&
(ne3 % a->ne[3] == 0)
);
GGML_ASSERT(can_repeat);
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
result->op = GGML_OP_REPEAT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_repeat_back
struct ggml_tensor * ggml_repeat_back(
@@ -6967,6 +7031,20 @@ struct ggml_tensor * ggml_sigmoid_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
}
// ggml_softplus
struct ggml_tensor * ggml_softplus(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS);
}
struct ggml_tensor * ggml_softplus_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SOFTPLUS);
}
// ggml_gelu
struct ggml_tensor * ggml_gelu(
@@ -7110,6 +7188,20 @@ struct ggml_tensor * ggml_hardsigmoid(
return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
}
// ggml exp
struct ggml_tensor * ggml_exp(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
}
struct ggml_tensor * ggml_exp_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
}
// =============== from mainline begin =====================================
// ggml_glu
@@ -7531,6 +7623,45 @@ struct ggml_tensor * ggml_group_norm_inplace(
return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
}
// ggml_l2_norm
static struct ggml_tensor * ggml_l2_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps,
bool inplace) {
bool is_node = false;
if (!inplace && a->grad) {
GGML_ABORT("fatal error"); // TODO: implement backward
is_node = true;
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
ggml_set_op_params_f32(result, 0, eps);
result->op = GGML_OP_L2_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_tensor * ggml_l2_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps) {
return ggml_l2_norm_impl(ctx, a, eps, false);
}
struct ggml_tensor * ggml_l2_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps) {
return ggml_l2_norm_impl(ctx, a, eps, true);
}
// ggml_mul_mat
struct ggml_tensor * ggml_mul_mat(
@@ -9698,6 +9829,74 @@ struct ggml_tensor * ggml_timestep_embedding(
return result;
}
// ggml_tri
struct ggml_tensor * ggml_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_tri_type type) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(a->ne[0] == a->ne[1]);
bool is_node = false;
if (a->grad) {
GGML_ABORT("fatal error"); // TODO: implement backward
is_node = true;
}
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
ggml_set_op_params_i32(result, 0, (int32_t) type);
result->op = GGML_OP_TRI;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_fill
static struct ggml_tensor * ggml_fill_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c,
bool inplace) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(a));
bool is_node = false;
if (!inplace && a->grad) {
GGML_ABORT("fatal error"); // TODO: implement backward
is_node = true;
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
ggml_set_op_params_f32(result, 0, c);
result->op = GGML_OP_FILL;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_tensor * ggml_fill(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c) {
return ggml_fill_impl(ctx, a, c, false);
}
struct ggml_tensor * ggml_fill_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c) {
return ggml_fill_impl(ctx, a, c, true);
}
// ggml_argsort
struct ggml_tensor * ggml_argsort(
@@ -9980,6 +10179,44 @@ struct ggml_tensor * ggml_flash_attn_back(
return result;
}
// ggml_solve_tri
struct ggml_tensor * ggml_solve_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
bool left,
bool lower,
bool uni) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(b->type == GGML_TYPE_F32);
GGML_ASSERT(a->ne[0] == a->ne[1]);
GGML_ASSERT(a->ne[1] == b->ne[1]);
GGML_ASSERT(a->ne[2] == b->ne[2]);
GGML_ASSERT(a->ne[3] == b->ne[3]);
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(ggml_is_contiguous(b));
GGML_ASSERT(lower && left && !uni); // TODO: support other variants
bool is_node = false;
if (a->grad || b->grad) {
GGML_ABORT("fatal error"); // TODO: implement backward
is_node = true;
}
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
result->op = GGML_OP_SOLVE_TRI;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_ssm_conv
struct ggml_tensor * ggml_ssm_conv(
@@ -13789,6 +14026,62 @@ static void ggml_compute_forward_sum(
}
}
// ggml_compute_forward_cumsum
static void ggml_compute_forward_cumsum_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(dst->nb[0] == sizeof(float));
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne01);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ne01*ne02*ne03;
for (int64_t ir = ith; ir < nr; ir += nth) {
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = ir - i03*ne02*ne01 - i02*ne01;
const float * src_row = (const float *) ((const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
float * dst_row = ( float *) (( char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
float acc = 0.0f;
for (int64_t i00 = 0; i00 < ne00; ++i00) {
acc += src_row[i00];
dst_row[i00] = acc;
}
}
}
static void ggml_compute_forward_cumsum(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_cumsum_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_sum_rows
static void ggml_compute_forward_sum_rows_f32(
@@ -14665,6 +14958,100 @@ static void ggml_compute_forward_sigmoid(
}
}
// ggml_compute_forward_exp
static void ggml_compute_forward_exp_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
const int dr = (nr + nth - 1)/nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_exp_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
}
}
static void ggml_compute_forward_exp(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_exp_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_softplus
static void ggml_compute_forward_softplus_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
const int dr = (nr + nth - 1)/nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_softplus_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
}
}
static void ggml_compute_forward_softplus(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_softplus_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_gelu
static void ggml_compute_forward_gelu_f32(
@@ -14724,6 +15111,102 @@ static void ggml_compute_forward_gelu(
}
}
// ggml_compute_forward_fill
static void ggml_compute_forward_fill_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
const float c = ggml_get_op_params_f32(dst, 0);
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ne1*ne2*ne3;
for (int64_t ir = ith; ir < nr; ir += nth) {
const int64_t i03 = ir/(ne2*ne1);
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
const int64_t i01 = ir - i03*ne2*ne1 - i02*ne1;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
ggml_vec_set_f32(ne0, dst_ptr, c);
}
}
static void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
ggml_compute_forward_fill_f32(params, dst);
}
// ggml_compute_forward_tri
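// bipred(i0, i01) decides per element whether column i0 of row i01 is kept; elements for
// which it returns false are zeroed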
static bool ggml_tri_lower_pred(int i, int r) {
return i < r;
}
static bool ggml_tri_lower_diag_pred(int i, int r) {
return i <= r;
}
static bool ggml_tri_upper_pred(int i, int r) {
return i > r;
}
static bool ggml_tri_upper_diag_pred(int i, int r) {
return i >= r;
}
static void ggml_compute_forward_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const enum ggml_tri_type ttype = (enum ggml_tri_type) ggml_get_op_params_i32(dst, 0);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_TENSOR_UNARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ne01*ne02*ne03;
bool (*bipred)(int, int);
switch (ttype) {
case GGML_TRI_TYPE_LOWER: bipred = ggml_tri_lower_pred; break;
case GGML_TRI_TYPE_LOWER_DIAG: bipred = ggml_tri_lower_diag_pred; break;
case GGML_TRI_TYPE_UPPER: bipred = ggml_tri_upper_pred; break;
case GGML_TRI_TYPE_UPPER_DIAG: bipred = ggml_tri_upper_diag_pred; break;
default: GGML_ABORT("invalid tri type");
}
for (int64_t ir = ith; ir < nr; ir += nth) {
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = ir - i03*ne02*ne01 - i02*ne01;
const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
for (int i0 = 0; i0 < ne0; ++i0) {
dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
}
}
}
static void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_tri_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_gelu_quick
static void ggml_compute_forward_gelu_quick_f32(
@@ -16781,6 +17264,65 @@ static void ggml_compute_forward_mul_mat_up_gate(
}
#endif
// ggml_compute_forward_l2_norm
static void ggml_compute_forward_l2_norm_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_UNARY_OP_LOCALS
const float eps = ggml_get_op_params_f32(dst, 0);
GGML_ASSERT(eps >= 0.0f);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (const float *) ((const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float) (x[i00] * x[i00]);
}
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
ggml_vec_scale_f32(ne00, y, scale);
}
}
}
}
static void ggml_compute_forward_l2_norm(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_l2_norm_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_out_prod
static void ggml_compute_forward_out_prod_f32(
@@ -21515,6 +22057,75 @@ static void ggml_compute_forward_ssm_scan(
}
}
// ggml_compute_forward_solve_tri
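// forward substitution for A*X = B with A lower-triangular (left, non-unit diagonal);
// work is split across threads by (batch, column-of-B) pairs, so each thread walks the rows
// of its own column in order while columns stay independent of each other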
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ne00 == ne01);
GGML_ASSERT(ne0 == ne10);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
const int ith = params->ith;
const int nth = params->nth;
const int64_t k = ne10;
const int64_t n = ne11;
const int64_t nr = ne02*ne03*k;
const int64_t dr = (nr + nth - 1)/nth;
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
const float * A = (const float *) src0->data;
const float * B = (const float *) src1->data;
float * X = ( float *) dst->data;
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*k);
const int64_t i02 = (ir - i03*ne02*k)/k;
const int64_t i01 = ir - i03*ne02*k - i02*k;
const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
for (int64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f;
for (int64_t t = 0; t < i00; ++t) {
sum += A_batch[i00*n + t] * X_batch[t*k + i01];
}
const float diag = A_batch[i00*n + i00];
GGML_ASSERT(diag != 0.0f);
X_batch[i00*k + i01] = (B_batch[i00*k + i01] - sum) / diag;
}
}
}
static void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_compute_forward_solve_tri_f32(params, dst);
} else {
GGML_ABORT("fatal error");
}
}
// ggml_compute_forward_win_part
static void ggml_compute_forward_win_part_f32(
@@ -21708,6 +22319,14 @@ static void ggml_compute_forward_unary(
{
ggml_compute_forward_hardsigmoid(params, dst);
} break;
case GGML_UNARY_OP_EXP:
{
ggml_compute_forward_exp(params, dst);
} break;
case GGML_UNARY_OP_SOFTPLUS:
{
ggml_compute_forward_softplus(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
@@ -22938,6 +23557,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
ggml_compute_forward_sum_rows(params, tensor);
}
} break;
case GGML_OP_CUMSUM:
{
ggml_compute_forward_cumsum(params, tensor);
} break;
case GGML_OP_MEAN:
{
ggml_compute_forward_mean(params, tensor);
@@ -22986,6 +23609,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
{
ggml_compute_forward_group_norm(params, tensor);
} break;
case GGML_OP_L2_NORM:
{
ggml_compute_forward_l2_norm(params, tensor);
} break;
case GGML_OP_MUL_MAT:
{
i = ggml_compute_forward_mul_mat(params, tensor, cgraph, i);
@@ -23157,6 +23784,14 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
{
ggml_compute_forward_timestep_embedding(params, tensor);
} break;
case GGML_OP_TRI:
{
ggml_compute_forward_tri(params, tensor);
} break;
case GGML_OP_FILL:
{
ggml_compute_forward_fill(params, tensor);
} break;
case GGML_OP_ARGSORT:
{
if (false && i + 5 < cgraph->n_nodes &&
@@ -23202,6 +23837,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
{
ggml_compute_forward_ssm_scan(params, tensor);
} break;
case GGML_OP_SOLVE_TRI:
{
ggml_compute_forward_solve_tri(params, tensor);
} break;
case GGML_OP_WIN_PART:
{
ggml_compute_forward_win_part(params, tensor);
@@ -24229,6 +24868,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_CUMSUM:
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_L2_NORM:
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_TIMESTEP_EMBEDDING:
{
GGML_ABORT("fatal error"); // TODO: not implemented
@@ -24249,6 +24896,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_TRI:
case GGML_OP_FILL:
case GGML_OP_SOLVE_TRI:
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_FLASH_ATTN_EXT:
{
struct ggml_tensor * flash_grad = NULL;
@@ -24391,6 +25044,24 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table);
}
} break;
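// d/dx exp(x) = exp(x), so the incoming gradient is scaled by the forward result (tensor)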
case GGML_UNARY_OP_EXP:
{
if (src0->grad) {
src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx, tensor, tensor->grad),
zero_table);
}
} break;
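// d/dx softplus(x) = sigmoid(x), so the incoming gradient is scaled by sigmoid(src0)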
case GGML_UNARY_OP_SOFTPLUS:
{
if (src0->grad) {
src0->grad = ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx, tensor->grad, ggml_sigmoid(ctx, src0)),
zero_table);
}
} break;
default:
GGML_ABORT("fatal error");
}
@@ -24870,6 +25541,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_ACC:
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
case GGML_OP_FILL:
case GGML_OP_MULTI_ADD:
case GGML_OP_MUL_MULTI_ADD:
case GGML_OP_HADAMARD:
@@ -24908,6 +25582,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_SWIGLU:
case GGML_UNARY_OP_SWIGLU_OAI:
{
@@ -24946,12 +25622,14 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_FUSED_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_GROUP_NORM:
case GGML_OP_L2_NORM:
case GGML_OP_CONCAT:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
case GGML_OP_MOE_FUSED_UP_GATE:
case GGML_OP_FUSED_UP_GATE:
case GGML_OP_OUT_PROD:
case GGML_OP_SOLVE_TRI:
{
n_tasks = n_threads;
} break;