Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-01-26 17:20:01 +00:00

Adding Ling/Ring (a.k.a., Bailing-MoE2) support (#833)

* Adding Ling/Ring (a.k.a., Bailing-MoE2)
* Add expert group selection (not working, so turned off)
* BailingMoE2 conversion
* WIP
* Bits and pieces

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
ggml/include/ggml.h

@@ -630,6 +630,7 @@ extern "C" {
        GGML_OP_TRANSPOSE,
        GGML_OP_GET_ROWS,
        GGML_OP_GET_ROWS_BACK,
        GGML_OP_SET_ROWS,
        GGML_OP_DIAG,
        GGML_OP_DIAG_MASK_INF,
        GGML_OP_DIAG_MASK_ZERO,
@@ -1559,6 +1560,19 @@ extern "C" {
            struct ggml_tensor  * a,
            float                 s);

    // x = s * a + b
    GGML_API struct ggml_tensor * ggml_scale_bias(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 s,
            float                 b);

    GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 s,
            float                 b);
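For illustration only (not part of this commit): a minimal sketch of calling the new scale-with-bias op from the ggml graph API. The context `ctx` and the F32 tensor `x` are hypothetical placeholders.

    // y = 0.5 * x + 1.0, element-wise; x is assumed to be an existing F32 tensor in ctx
    struct ggml_tensor * y  = ggml_scale_bias(ctx, x, 0.5f, 1.0f);
    // same computation, but writing the result back into x instead of allocating a new tensor
    struct ggml_tensor * y2 = ggml_scale_bias_inplace(ctx, x, 0.5f, 1.0f);
    // ggml_scale(ctx, x, s) stays equivalent to ggml_scale_bias(ctx, x, s, 0.0f)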

    GGML_API struct ggml_tensor * ggml_softcap(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
@@ -1781,6 +1795,23 @@ extern "C" {
            struct ggml_tensor  * b,
            struct ggml_tensor  * c);

    // a TD  [n_embd, ne1, ne2, ne3]
    // b TS  [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
    // c I64 [n_rows, ne11, ne12, 1]      | c[i] in [0, ne1)
    //
    // undefined behavior if destination rows overlap
    //
    // broadcast:
    //   ne2 % ne11 == 0
    //   ne3 % ne12 == 0
    //
    // return view(a)
    GGML_API struct ggml_tensor * ggml_set_rows(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // destination
            struct ggml_tensor  * b,  // source
            struct ggml_tensor  * c); // row indices
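An illustrative sketch (not part of the diff) of the ggml_set_rows shape contract: copy the rows of an F32 source `b` into destination `a` at the positions given by the index tensor `c`. The tensor names and sizes below are hypothetical.

    // a: 16 destination rows of width n_embd; b: 4 source rows; c: 4 destination row indices in [0, 16)
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, 16);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, 4);
    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 4);
    // the result is a view of a; evaluating it writes row i of b into row c[i] of a
    struct ggml_tensor * out = ggml_set_rows(ctx, a, b, c);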

    GGML_API struct ggml_tensor * ggml_diag(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);
ggml/src/ggml-cuda.cu

@@ -44,6 +44,8 @@
#include "ggml-cuda/topk-moe.cuh"
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/argmax.cuh"

#include <algorithm>
#include <array>
@@ -3105,12 +3107,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
    auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr;

    switch (dst->op) {
        case GGML_OP_ARGMAX:
            ggml_cuda_argmax(ctx, dst);
            break;
        case GGML_OP_REPEAT:
            ggml_cuda_op_repeat(ctx, dst);
            break;
        case GGML_OP_GET_ROWS:
            ggml_cuda_op_get_rows(ctx, dst);
            break;
        case GGML_OP_SET_ROWS:
            ggml_cuda_op_set_rows(ctx, dst);
            break;
        case GGML_OP_DUP:
            ggml_cuda_dup(ctx, dst);
            break;
@@ -4204,6 +4212,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                return false;
            }
        } break;
        case GGML_OP_SET_ROWS:
        {
            return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
                    op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
                    op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
                   op->src[0]->type == GGML_TYPE_F32 &&
                   (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
        } break;
        case GGML_OP_CPY:
        {
            ggml_type src0_type = op->src[0]->type;
@@ -4260,6 +4276,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
            }
            return false;
        } break;
        case GGML_OP_ARGMAX:
            return true;
        case GGML_OP_DUP:
        case GGML_OP_REPEAT:
        case GGML_OP_CONCAT:
ggml/src/ggml-cuda/argmax.cu (new file, 90 lines)
@@ -0,0 +1,90 @@
#include <algorithm>
#include <cstdint>

#include "argmax.cuh"
#include "common.cuh"

static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
    const int64_t row = blockIdx.x;

    float maxval = -FLT_MAX;
    int   argmax = -1;
    const float * rowx = x + row * ncols;

    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
        const float val = rowx[col];
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

    const int n_warps = blockDim.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;
    if (n_warps > 1) {
        constexpr int max_warps = 1024 / WARP_SIZE;
        __shared__ float shared_maxval[max_warps];
        __shared__ int   shared_argmax[max_warps];
        if (lane_id == 0) {
            shared_maxval[warp_id] = maxval;
            shared_argmax[warp_id] = argmax;
        }

        __syncthreads();

        if (warp_id == 0) {
            if (lane_id < n_warps) {
                maxval = shared_maxval[lane_id];
                argmax = shared_argmax[lane_id];
            }
#pragma unroll
            for (int offset = 16; offset > 0; offset >>= 1) {
                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
                if (val > maxval) {
                    maxval = val;
                    argmax = col;
                }
            }
        }
    }

    if (warp_id == 0 && lane_id == 0) {
        dst[row] = argmax;
    }
}

void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_I32);

    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t ne00  = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    const float * src0_d = (const float *) src0->data;
    int32_t     * dst_d  = (int32_t *) dst->data;

    cudaStream_t stream = ctx.stream();

    const int64_t num_blocks  = nrows;
    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
    const dim3 blocks_dim(num_threads, 1, 1);
    const dim3 blocks_num(num_blocks, 1, 1);

    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
}
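To make the reduction in the new kernel easier to follow, here is its warp-level building block in isolation: each lane holds a (value, index) pair, and a butterfly exchange via __shfl_xor_sync keeps the larger value at every step, so after log2(32) rounds every lane holds the row's argmax. This is a standalone sketch, not code from the commit.

    // Reduce a (maxval, argmax) pair across the 32 lanes of one warp.
    // After the loop every lane holds the same winning pair.
    __device__ void warp_argmax(float & maxval, int & argmax) {
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            const float other_val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, 32);
            const int   other_idx = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, 32);
            if (other_val > maxval) {
                maxval = other_val;
                argmax = other_idx;
            }
        }
    }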

ggml/src/ggml-cuda/argmax.cuh (new file, 3 lines)
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/argsort.cu

@@ -13,9 +13,20 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    b = tmp;
}

template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad,
                                         int min_experts, float thresh_experts) {
struct store_ser {
    constexpr static bool has_thresh = true;
    int min_experts;
    float thresh_experts;
    store_ser(int min, float thresh) : min_experts(min), thresh_experts(thresh) {}
};

struct store {
    constexpr static bool has_thresh = false;
};

template<ggml_sort_order order, typename Store>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, Store s) {
                                         // int min_experts, float thresh_experts) {
    // bitonic sort
    int col = threadIdx.x;
    int row = blockIdx.y;
@@ -58,19 +69,30 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
        }
    }

    if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
    if constexpr (Store::has_thresh) {
        __syncthreads();
        float max_val = x_row[dst_row[0]];
        if (col < ncols) {
            dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1;
            dst[row * ncols + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
        }
    }
    else {
        // copy the result to dst without the padding
    } else {
        if (col < ncols) {
            dst[row * ncols + col] = dst_row[col];
        }
    }
    //if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
    //    __syncthreads();
    //    float max_val = x_row[dst_row[0]];
    //    if (col < ncols) {
    //        dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1;
    //    }
    //}
    //else {
    //    // copy the result to dst without the padding
    //    if (col < ncols) {
    //        dst[row * ncols + col] = dst_row[col];
    //    }
    //}
}

static int next_power_of_2(int x) {
@@ -94,9 +116,21 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

    if (order == GGML_SORT_ORDER_ASC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
        if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
            k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
                    {min_experts, thresh_experts});
        } else {
            k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
        }
        //k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
    } else if (order == GGML_SORT_ORDER_DESC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
        if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
            k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
                    {min_experts, thresh_experts});
        } else {
            k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
        }
        //k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
    } else {
        GGML_ABORT("fatal error");
    }

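The argsort refactor above replaces the two runtime kernel parameters (min_experts, thresh_experts) with a compile-time policy type: store_ser carries the threshold state and sets has_thresh = true, while the empty store sets it to false, so the `if constexpr (Store::has_thresh)` branch is compiled out of the plain argsort path entirely. A minimal standalone sketch of the same pattern, with hypothetical names and not taken from the commit:

    struct with_threshold { static constexpr bool has_thresh = true;  int min_k; float thresh; };
    struct plain          { static constexpr bool has_thresh = false; };

    template <typename Policy>
    __global__ void write_result(const float * vals, int * out, int n, Policy p) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= n) return;
        if constexpr (Policy::has_thresh) {
            // threshold variant: drop entries below p.thresh, but always keep the first p.min_k
            out[i] = (i < p.min_k || vals[i] >= p.thresh) ? i : -1;
        } else {
            // plain variant: the threshold code above is never even instantiated
            out[i] = i;
        }
    }

The host side then picks `write_result<with_threshold>` or `write_result<plain>` once per launch, exactly as argsort_f32_i32_cuda does with store_ser/store.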
ggml/src/ggml-cuda/cpy.cu

@@ -422,7 +422,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
    } else
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
    {
        CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
        if (src0->type == GGML_TYPE_F32) {
            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
        } else {
            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
        }
    }
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
    ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@@ -473,6 +477,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
    ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
    ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
    ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
    ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
    // This is needed for MLA with mla=2 when using q8_0 cache.
    transpose_q8_0(ctx, src0, src1);
@@ -498,7 +506,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
        return nullptr;
        // Prioritize CUDA graph compatibility over direct memory copy optimization.
        // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
        if (src0->type == GGML_TYPE_F32) {
            return (void*) cpy_flt<cpy_1_flt<float, float>>;
        } else {
            return nullptr;
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_flt<cpy_1_flt<float, float>>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
@@ -545,6 +559,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
        return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
        return (void *)transpose_q8_0;
    } else {

ggml/src/ggml-cuda/scale.cu

@@ -1,19 +1,21 @@
#include "scale.cuh"

static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
#define MAX_GRIDDIM_X 0x7FFFFFFF

    if (i >= k) {
        return;
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
    int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
    int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;

    for (int64_t i = tid; i < nelements; i += stride) {
        dst[i] = scale * x[i] + bias;
    }

    dst[i] = scale * x[i];
}

static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
    const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
    // When will we be scaling tensors with more than 2^39 elements?
    //scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
}

void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
@@ -25,7 +27,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    float scale;
    memcpy(&scale, dst->op_params, sizeof(float));
    float bias;
    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
    memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));

    scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
    scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
}
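The rewritten kernel switches from one-thread-per-element (which bounds coverage by gridDim.x * blockDim.x) to a grid-stride loop, so a capped grid (e.g. at MAX_GRIDDIM_X, as the commented-out launch hints) can still cover any int64_t element count, and the fused bias add is free. A generic sketch of the pattern, independent of this file and with a hypothetical kernel name:

    // Generic grid-stride loop: each thread handles elements tid, tid+stride, tid+2*stride, ...
    // so the grid size can be capped without losing coverage of large tensors.
    __global__ void axpb(const float * x, float * y, float a, float b, int64_t n) {
        const int64_t tid    = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
        const int64_t stride = (int64_t)blockDim.x * gridDim.x;
        for (int64_t i = tid; i < n; i += stride) {
            y[i] = a * x[i] + b;
        }
    }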

ggml/src/ggml-cuda/set-rows.cu (new file, 277 lines)
@@ -0,0 +1,277 @@
#include "set-rows.cuh"
#include "cpy-utils.cuh"
#include "convert.cuh"

typedef void (*set_rows_kernel_t)(const char * src, char * dst);

// Generic quantized set_rows kernel template
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static __global__ void k_set_rows_quant(
        const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s10, const int64_t s11, const int64_t s12,
        const int64_t s1, const int64_t s2, const int64_t s3) {

    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;

    if (i >= ne_total) {
        return;
    }

    const int64_t i_base = i * qk;
    const int64_t i03 = i_base / (ne00 * ne01 * ne02);
    const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
    const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;

    const int64_t i12 = i03 % ne12;
    const int64_t i11 = i02 % ne11;
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

    const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
    block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);

    const float * src_block = src0_row + i00;
    block_type * dst_block  = dst_row_ptr + i00 / qk;

    quantize_func(src_block, dst_block);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne13);
}

// Template dispatch function for quantized set_rows
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static void set_rows_cuda_quant(
        const float * src0_d, const idx_t * src1_d, block_type * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    GGML_ASSERT(ne00 % qk == 0);
    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(num_blocks);

    const int64_t s01 = nb01/sizeof(float);
    const int64_t s02 = nb02/sizeof(float);
    const int64_t s03 = nb03/sizeof(float);
    const int64_t s10 = nb10/sizeof(idx_t);
    const int64_t s11 = nb11/sizeof(idx_t);
    const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1  = nb1;
    const int64_t s2  = nb2;
    const int64_t s3  = nb3;

    if (ne_total > 0) {
        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
                src0_d, src1_d, dst_d,
                ne00, ne01, ne02, ne03,
                ne10, ne11, ne12, ne13,
                s01, s02, s03,
                s10, s11, s12,
                s1, s2, s3);
    }
}

template<typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(
        const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s10, const int64_t s11, const int64_t s12,
        const int64_t s1, const int64_t s2, const int64_t s3) {

    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;

    if (i >= ne_total) {
        return;
    }

    const int64_t i03 = i / (ne00 * ne01 * ne02);
    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;

    const int64_t i12 = i03 % ne12;
    const int64_t i11 = i02 % ne11;
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

    const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
    dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;

    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne13);
}

template<typename src_t, typename idx_t, typename dst_t>
static void set_rows_cuda(
        const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(num_blocks);

    const int64_t s01 = nb01/sizeof(src_t);
    const int64_t s02 = nb02/sizeof(src_t);
    const int64_t s03 = nb03/sizeof(src_t);
    const int64_t s10 = nb10/sizeof(idx_t);
    const int64_t s11 = nb11/sizeof(idx_t);
    const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1  = nb1/sizeof(dst_t);
    const int64_t s2  = nb2/sizeof(dst_t);
    const int64_t s3  = nb3/sizeof(dst_t);

    if (ne_total > 0) {
        k_set_rows<<<grid_size, block_size, 0, stream>>>(
                src0_d, src1_d, dst_d,
                ne00, ne01, ne02, ne03,
                ne10, ne11, ne12, ne13,
                s01, s02, s03,
                s10, s11, s12,
                s1, s2, s3);
    }
}

template<typename src_t, typename idx_t>
static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const src_t * src0_d = (const src_t *)src0->data;
    const idx_t * src1_d = (const idx_t *)src1->data;

    GGML_TENSOR_BINARY_OP_LOCALS

    cudaStream_t stream = ctx.stream();

    if (dst->type == GGML_TYPE_F32) {
        set_rows_cuda(
            src0_d, src1_d, (float*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_F16) {
        set_rows_cuda(
            src0_d, src1_d, (half*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_BF16) {
        set_rows_cuda(
            src0_d, src1_d, (nv_bfloat16*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_0) {
        set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
            src0_d, src1_d, (block_q4_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_1) {
        set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
            src0_d, src1_d, (block_q4_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_0) {
        set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
            src0_d, src1_d, (block_q5_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_1) {
        set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
            src0_d, src1_d, (block_q5_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_Q8_0) {
        set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
            src0_d, src1_d, (block_q8_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_IQ4_NL) {
        set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
            src0_d, src1_d, (block_iq4_nl*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else {
        GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
    }
}


void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);

    if (src1->type == GGML_TYPE_I64) {
        set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
    } else {
        set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
    }
}
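The kernels above linearize the whole source tensor into a single thread index, recover the 4-D coordinates by division and modulo, and then fetch the destination row number from src1 with the batch dimensions broadcast via `%`. A plain C sketch of that index arithmetic, illustrative only and with hypothetical names:

    // Recover (i01, i02, i03) from a flat index i over an ne00 x ne01 x ne02 x ne03 tensor,
    // then locate the row index in the [n_rows, ne11, ne12] index tensor with modulo broadcast.
    static inline int64_t dst_row_for(int64_t i,
                                      int64_t ne00, int64_t ne01, int64_t ne02,
                                      int64_t ne11, int64_t ne12,
                                      const int64_t * idx, int64_t s10, int64_t s11, int64_t s12) {
        const int64_t i03 = i / (ne00 * ne01 * ne02);
        const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00 * ne01);
        const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
        const int64_t i12 = i03 % ne12;   // broadcast over the 4th dimension
        const int64_t i11 = i02 % ne11;   // broadcast over the 3rd dimension
        const int64_t i10 = i01;          // one index entry per source row
        return idx[i10*s10 + i11*s11 + i12*s12];
    }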

ggml/src/ggml-cuda/set-rows.cuh (new file, 7 lines)
@@ -0,0 +1,7 @@
#pragma once

#include "common.cuh"

#define CUDA_SET_ROWS_BLOCK_SIZE 256

void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml.c (238 changed lines)
@@ -3206,6 +3206,53 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#endif
}

inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
#if defined(GGML_USE_ACCELERATE)
    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
#elif defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
    // scalar ; TODO: Write SVE code
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#elif defined(__riscv_v_intrinsic)
    for (int i = 0, avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e32m8(n - i);
        vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
        vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl);
        vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl);
        __riscv_vse32_v_f32m8(&y[i], ny, avl);
    }
#else
    const int np = (n & ~(GGML_F32_STEP - 1));

    GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
    GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);

    GGML_F32_VEC ay[GGML_F32_ARR];

    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);

            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#endif
}
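Whichever path is compiled (Accelerate, SVE/RISC-V, generic SIMD, or scalar), ggml_vec_mad1_f32 computes y[i] = x[i]*s + b; in the generic SIMD branch this is the fused multiply-add GGML_F32_VEC_FMA(vb, ay, vs), i.e. b + x*s. A tiny illustrative check, not part of the commit (the function is internal to ggml.c, so this would only compile there):

    static void mad1_reference_check(void) {
        float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        float y[4];
        ggml_vec_mad1_f32(4, y, x, 2.0f, 0.5f);
        // y is now {2.5f, 4.5f, 6.5f, 8.5f}, identical to y[i] = x[i]*2.0f + 0.5f in plain C
    }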

inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));
@@ -4185,6 +4232,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "TRANSPOSE",
    "GET_ROWS",
    "GET_ROWS_BACK",
    "SET_ROWS",
    "DIAG",
    "DIAG_MASK_INF",
    "DIAG_MASK_ZERO",
@@ -4239,7 +4287,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "GLU",
};

static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",
@@ -4287,6 +4335,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "transpose(x)",
    "get_rows(x)",
    "get_rows_back(x)",
    "set_rows(x)",
    "diag(x)",
    "diag_mask_inf(x)",
    "diag_mask_zero(x)",
@@ -4341,7 +4390,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "glu(x),"
};

static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -7544,6 +7593,7 @@ static struct ggml_tensor * ggml_scale_impl(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float s,
        float b,
        bool inplace) {
    GGML_ASSERT(ggml_is_padded_1d(a));

@@ -7555,7 +7605,8 @@ static struct ggml_tensor * ggml_scale_impl(

    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

    ggml_set_op_params(result, &s, sizeof(s));
    float params[2] = {s, b};
    ggml_set_op_params(result, &params, sizeof(params));

    result->op = GGML_OP_SCALE;
    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7568,14 +7619,30 @@ struct ggml_tensor * ggml_scale(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float s) {
    return ggml_scale_impl(ctx, a, s, false);
    return ggml_scale_impl(ctx, a, s, 0.f, false);
}

struct ggml_tensor * ggml_scale_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float s) {
    return ggml_scale_impl(ctx, a, s, true);
    return ggml_scale_impl(ctx, a, s, 0.f, true);
}

struct ggml_tensor * ggml_scale_bias(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float s,
        float b) {
    return ggml_scale_impl(ctx, a, s, b, false);
}

struct ggml_tensor * ggml_scale_bias_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float s,
        float b) {
    return ggml_scale_impl(ctx, a, s, b, true);
}
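With the bias added, GGML_OP_SCALE now packs two floats into op_params instead of one, and backends read them back positionally; graphs built through ggml_scale()/ggml_scale_inplace() simply store 0.0f for the bias. A small sketch of the read side, mirroring what the CPU and CUDA implementations in this commit do (hypothetical helper name):

    // Hypothetical helper: unpack the {scale, bias} pair that ggml_scale_impl stores in op_params.
    static void get_scale_bias(const struct ggml_tensor * t, float * s, float * b) {
        memcpy(s, (const float *) t->op_params + 0, sizeof(float));
        memcpy(b, (const float *) t->op_params + 1, sizeof(float));
    }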

// ggml_softcap
@@ -8294,6 +8361,36 @@ struct ggml_tensor * ggml_get_rows_back(
    return result;
}

// ggml_set_rows

struct ggml_tensor * ggml_set_rows(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        struct ggml_tensor * c) {
    GGML_ASSERT(a->ne[0] == b->ne[0]);
    GGML_ASSERT(a->ne[2] == b->ne[2]);
    GGML_ASSERT(a->ne[3] == b->ne[3]);
    GGML_ASSERT(b->ne[1] == c->ne[0]);
    GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
    GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
    GGML_ASSERT(c->ne[3] == 1);
    GGML_ASSERT(b->type == GGML_TYPE_F32);
    GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32);

    GGML_ASSERT(ggml_is_contiguous_rows(a));
    GGML_ASSERT(ggml_is_contiguous_rows(b));

    struct ggml_tensor * result = ggml_view_tensor(ctx, a);

    result->op = GGML_OP_SET_ROWS;
    result->src[0] = b;
    result->src[1] = c;
    result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)

    return result;
}

// ggml_diag

struct ggml_tensor * ggml_diag(
@@ -16594,8 +16691,9 @@ static void ggml_compute_forward_scale_f32(
    GGML_ASSERT(ggml_are_same_shape(src0, dst));

    // scale factor
    float v;
    memcpy(&v, dst->op_params, sizeof(float));
    float s, b;
    memcpy(&s, (const float *)dst->op_params + 0, sizeof(float));
    memcpy(&b, (const float *)dst->op_params + 1, sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;
@@ -16614,12 +16712,21 @@ static void ggml_compute_forward_scale_f32(

    const size_t nb1 = dst->nb[1];

    for (int i1 = ir0; i1 < ir1; i1++) {
        if (dst->data != src0->data) {
            // src0 is same shape as dst => same indices
            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
    if (b == 0.0f) {
        for (int i1 = ir0; i1 < ir1; i1++) {
            if (dst->data != src0->data) {
                // src0 is same shape as dst => same indices
                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
            }
            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
        }
    } else {
        for (int i1 = ir0; i1 < ir1; i1++) {
            ggml_vec_mad1_f32(nc,
                    (float *) ((char *) dst->data + i1*nb1),
                    (float *) ((char *) src0->data + i1*nb1),
                    s, b);
        }
        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
    }
}

@@ -17455,6 +17562,101 @@ static void ggml_compute_forward_get_rows_back(
    //}
}

static void ggml_compute_forward_set_rows_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    const int64_t nc = ne00;
    const int64_t nr = ne01;

    assert(ne0 == nc);
    assert(ne2 == ne02);
    assert(ne3 == ne03);
    assert(src0->type == GGML_TYPE_F32);
    assert(ne02 % ne11 == 0);
    assert(ne03 % ne12 == 0);

    const int ith = params->ith;
    const int nth = params->nth;

    // rows per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    ggml_from_float_t const from_float = type_traits[dst->type].from_float;

    if (src1->type == GGML_TYPE_I64) {
        for (int64_t i03 = 0; i03 < ne03; ++i03) {
            for (int64_t i02 = 0; i02 < ne02; ++i02) {
                for (int64_t i = ir0; i < ir1; ++i) {
                    const int64_t i12 = i03%ne12;
                    const int64_t i11 = i02%ne11;
                    const int64_t i10 = i;

                    const int64_t i1 = *(int64_t*) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

                    GGML_ASSERT(i1 >= 0 && i1 < ne1);

                    from_float((const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
                               ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
                }
            }
        }
    }
    else if (src1->type == GGML_TYPE_I32) {
        for (int64_t i03 = 0; i03 < ne03; ++i03) {
            for (int64_t i02 = 0; i02 < ne02; ++i02) {
                for (int64_t i = ir0; i < ir1; ++i) {
                    const int64_t i12 = i03%ne12;
                    const int64_t i11 = i02%ne11;
                    const int64_t i10 = i;

                    const int64_t i1 = *(int32_t*) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

                    GGML_ASSERT(i1 >= 0 && i1 < ne1);

                    from_float((const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
                               ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
                }
            }
        }
    }
    else {
        GGML_ABORT("Fatal error");
    }
}

static void ggml_compute_forward_set_rows(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                if (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32) {
                    ggml_compute_forward_set_rows_f32(params, dst);
                } else {
                    GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
                }
            } break;
        default:
            {
                GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
            }
    }
}

// ggml_compute_forward_diag

static void ggml_compute_forward_diag_f32(
@@ -22208,10 +22410,15 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
            {
                ggml_compute_forward_get_rows(params, tensor);
            } break;

        case GGML_OP_GET_ROWS_BACK:
            {
                ggml_compute_forward_get_rows_back(params, tensor);
            } break;
        case GGML_OP_SET_ROWS:
            {
                ggml_compute_forward_set_rows(params, tensor);
            } break;
        case GGML_OP_DIAG:
            {
                ggml_compute_forward_diag(params, tensor);
@@ -22973,7 +23180,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                src0->grad =
                    ggml_add_or_set(ctx,
                        src0->grad,
                        ggml_scale_impl(ctx, tensor->grad, s, false),
                        ggml_scale_impl(ctx, tensor->grad, s, 0.0f, false),
                        zero_table);
            }
        } break;
@@ -23144,6 +23351,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                // noop
            }
        } break;
        case GGML_OP_SET_ROWS:
            {
                GGML_ABORT("fatal error"); // TODO: not implemented
            }
        case GGML_OP_GET_ROWS_BACK:
            {
                GGML_ABORT("fatal error"); // TODO: not implemented
@@ -24010,6 +24221,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                n_tasks = n_threads;
            } break;
        case GGML_OP_GET_ROWS:
        case GGML_OP_SET_ROWS:
            {
                // FIXME: get_rows can use additional threads, but the cost of launching additional threads
                // decreases performance with GPU offloading