mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-25 17:09:22 +00:00
cuda: neox works
This commit is contained in:
@@ -3062,6 +3062,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
|
||||
auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr;
|
||||
|
||||
//printf("%4d %s(%s)\n", i, ggml_op_name(dst->op), dst->name);
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ARGMAX:
|
||||
ggml_cuda_argmax(ctx, dst);
|
||||
@@ -3317,6 +3318,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
break;
|
||||
case GGML_OP_ROPE_BACK:
    ggml_cuda_op_rope_back(ctx, dst);
    break; // fixed: break was missing, so ROPE_BACK fell through and also ran ROPE_FAST
case GGML_OP_ROPE_FAST:
    ggml_cuda_op_rope_fast(ctx, dst);
    break;
case GGML_OP_ROPE_CACHE:
    ggml_cuda_op_rope_cache(ctx, dst);
    break;
|
||||
case GGML_OP_IM2COL:
|
||||
ggml_cuda_op_im2col(ctx, dst);
|
||||
@@ -4377,6 +4383,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_OP_SOFT_CAP_MAX:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
case GGML_OP_ROPE_FAST:
|
||||
case GGML_OP_ROPE_CACHE:
|
||||
return true;
|
||||
//case GGML_OP_ROPE:
|
||||
// return ggml_is_contiguous(op->src[0]);
|
||||
|
||||
@@ -121,6 +121,38 @@ static __global__ void rope_neox(
|
||||
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
// NeoX-style RoPE using a precomputed cos/sin cache (see k_rope_cache).
// src0: F32 activations; ne0 = head dim, ne1 = heads, ne2 = rows, with row/plane
//       strides s01/s02 given in elements.
// src1: rope cache laid out as [ne2][ne0] with interleaved (cos, sin) values at
//       even offsets, i.e. src1[i2*ne0 + i0] = cos, src1[i2*ne0 + i0 + 1] = sin.
// Each thread handles one even flat index i, producing two output values.
// Launcher must guarantee ne0 % 2 == 0.
static __global__ void rope_neox_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int ne2,
        int s01, int s02, int n_dims) {
    int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);

    // tail guard: the grid may overshoot the ne0*ne1*ne2 element count
    if (i >= ne0*ne1*ne2) {
        return;
    }
    //i = i0 + i1*ne0 + i2*ne0*ne1;
    int i2 = i / (ne0*ne1); i -= i2*ne0*ne1;
    int i1 = i / ne0;
    int i0 = i - i1*ne0;  // even, since i is even and ne0 is even

    // pair-base offsets (note the i0/2): the rotated pair lives at
    // {idst + 0, idst + n_dims/2} — NeoX rotates across the two halves of the head
    const int idst = i2*ne0*ne1 + i1*ne0 + i0/2;
    const int ix = i2*s02 + i1*s01 + i0/2;

    if (i0 >= n_dims) {
        // beyond the rotated dimensions: straight copy of the elements at i0, i0+1
        // (idst + i0/2 == i2*ne0*ne1 + i1*ne0 + i0, i.e. the original element offset)
        dst[idst + i0/2 + 0] = src0[ix + i0/2 + 0];
        dst[idst + i0/2 + 1] = src0[ix + i0/2 + 1];

        return;
    }

    const float x0 = src0[ix + 0];
    const float x1 = src0[ix + n_dims/2];

    // i0 is even, so (cos, sin) occupy two consecutive cache slots
    const float cos_theta = src1[i2*ne0 + i0 + 0];
    const float sin_theta = src1[i2*ne0 + i0 + 1];

    // 2D rotation of the (x0, x1) pair
    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
|
||||
|
||||
template<bool forward, bool has_ff, typename T>
|
||||
static __global__ void rope_multi(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
|
||||
@@ -272,6 +304,45 @@ static void rope_neox_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
// Launch rope_neox_fast over all ne00*ne01*ne02 elements of src0.
// Each thread covers two consecutive flat indices, hence the 2*CUDA_ROPE_BLOCK_SIZE
// divisor when sizing the grid. Requires an even head dimension.
static void rope_neox_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
        int n_dims, cudaStream_t stream) {
    GGML_ASSERT(ne00 % 2 == 0);
    const int nelem         = ne00*ne01*ne02;
    const int elem_per_block = 2*CUDA_ROPE_BLOCK_SIZE;
    const dim3 grid_dims((nelem + elem_per_block - 1)/elem_per_block, 1, 1);
    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
    rope_neox_fast<<<grid_dims, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne02, s01, s02, n_dims);
}
|
||||
|
||||
// Normal (non-NeoX) RoPE kernel using the precomputed cos/sin cache: rotates
// *adjacent* element pairs (i0, i0+1) instead of NeoX's half-split pairs
// (i0, i0 + n_dims/2). Cache layout matches k_rope_cache:
// src1[i2*ne0 + i0] = cos, src1[i2*ne0 + i0 + 1] = sin (i0 even) — the cache
// indexes theta by even i0, which is valid for both norm and NeoX pairings.
static __global__ void rope_norm_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int ne2,
        int s01, int s02, int n_dims) {
    int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);

    // tail guard: the grid may overshoot the element count
    if (i >= ne0*ne1*ne2) {
        return;
    }
    int i2 = i / (ne0*ne1); i -= i2*ne0*ne1;
    int i1 = i / ne0;
    int i0 = i - i1*ne0;  // even, since i is even and ne0 is even

    const int idst = i2*ne0*ne1 + i1*ne0 + i0;
    const int ix   = i2*s02 + i1*s01 + i0;

    if (i0 >= n_dims) {
        // beyond the rotated dimensions: pass the pair through unchanged
        dst[idst + 0] = src0[ix + 0];
        dst[idst + 1] = src0[ix + 1];
        return;
    }

    const float x0 = src0[ix + 0];
    const float x1 = src0[ix + 1];

    const float cos_theta = src1[i2*ne0 + i0 + 0];
    const float sin_theta = src1[i2*ne0 + i0 + 1];

    // 2D rotation of the adjacent (x0, x1) pair
    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
}

// Launcher for normal-style fast RoPE.
// Fixed: this was a TODO stub that launched rope_neox_fast, which rotates the
// wrong element pairs for this mode and silently produced incorrect results.
static void rope_norm_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
        int n_dims, cudaStream_t stream) {
    GGML_ASSERT(ne00 % 2 == 0);
    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
    const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(n_blocks, 1, 1);
    rope_norm_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne02, s01, s02, n_dims);
}
|
||||
|
||||
// NOTE(review): placeholder — this launches the *NeoX* kernel, so mrope
// (multi-section rotary embedding) is NOT actually computed here; results are
// wrong for genuine mrope inputs. A dedicated rope_multi_fast kernel (and
// section support in the cache) is still TODO — confirm this path is not
// reachable for mrope models before shipping.
static void rope_multi_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
        int n_dims, cudaStream_t stream) {
    GGML_ASSERT(ne00 % 2 == 0);
    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
    const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(n_blocks, 1, 1);
    // TODO
    rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne02, s01, s02, n_dims);
}
|
||||
|
||||
// NOTE(review): placeholder — this launches the *NeoX* kernel, not a vision-rope
// kernel; results are wrong for vision inputs (which pair dimensions differently
// and use section-dependent thetas). A dedicated rope_vision_fast kernel is
// still TODO — confirm this path is not reachable for vision models.
static void rope_vision_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
        int n_dims, cudaStream_t stream) {
    GGML_ASSERT(ne00 % 2 == 0);
    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
    const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(n_blocks, 1, 1);
    // TODO
    rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne02, s01, s02, n_dims);
}
|
||||
|
||||
template<bool forward, typename T>
|
||||
static void rope_multi_cuda(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
||||
@@ -448,3 +519,144 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
// Backward RoPE pass: dispatches to the shared implementation with
// forward=false (presumably inverting the rotation — rope_impl is defined
// elsewhere in this file).
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_rope_impl<false>(ctx, dst);
}
|
||||
|
||||
// Fill the RoPE cos/sin cache: for every (row i2, even dim index i0) write
// cos(theta) to dst[i] and sin(theta) to dst[i+1]. One thread per (cos, sin) pair.
// pos:          per-row positions (I32)
// freq_factors: optional per-dimension divisors, read only when has_ff
template <bool forward, bool has_ff>
static __global__ void k_rope_cache(int nelem, int ne0, float * dst, const int * pos, const float * freq_factors,
        float theta_scale, float freq_scale, rope_corr_dims corr_dims, float ext_factor, float attn_factor) {

    int i = 2*(blockIdx.x*blockDim.x + threadIdx.x);
    if (i >= nelem) {
        return;
    }
    int i2 = i / ne0;  // row (position index)
    int i0 = i % ne0;  // even element index within the row

    // theta_base = pos * theta_scale^(i0/2); the launcher sets
    // theta_scale = freq_base^(-2/n_dims)
    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, dst[i], dst[i+1]);
    // NOTE(review): rope_yarn is also templated on `forward` (definition not in
    // view); if it already flips the sine sign for the backward pass, this extra
    // negation double-applies it — confirm against rope_yarn's implementation.
    if constexpr (!forward) {
        dst[i+1] *= -1;
    }
}
|
||||
|
||||
// Precompute the RoPE cos/sin cache for GGML_OP_ROPE_CACHE.
// dst:    F32 cache holding interleaved (cos, sin) pairs, one pair per even i0
// src[0]: I32 positions, one per cache row (tpos->ne[0] == dst->ne[1])
// src[1]: optional F32 frequency factors (must provide >= n_dims/2 entries)
// The `forward` template argument selects the forward/backward rotation sign.
template <bool forward>
void ggml_cuda_op_rope_cache_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
    int sections[4];

    //const int n_past = ((int32_t *) dst->op_params)[0];
    const int n_dims     = ((int32_t *) dst->op_params)[1];
    const int mode       = ((int32_t *) dst->op_params)[2];
    //const int n_ctx    = ((int32_t *) dst->op_params)[3];
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];

    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);

    const struct ggml_tensor * tpos = dst->src[0];
    GGML_ASSERT(tpos->type == GGML_TYPE_I32);
    GGML_ASSERT(tpos->ne[0] == dst->ne[1]);  // one position per cache row

    GGML_ASSERT(n_dims <= dst->ne[0]);
    GGML_ASSERT(n_dims % 2 == 0);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    rope_corr_dims corr_dims;
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_mrope) {
        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
    }

    if (is_vision) {
        GGML_ASSERT(n_dims == dst->ne[0]);
    }

    // optional per-dimension frequency factors (e.g. long-rope models);
    // NOTE: previously fetched twice — the redundant second fetch was removed
    const float * freq_factors = NULL;
    if (dst->src[1] != NULL) {
        GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
        GGML_ASSERT(dst->src[1]->ne[0] >= n_dims / 2);
        freq_factors = (const float *) dst->src[1]->data;
    }

    const int * pos = (const int *) dst->src[0]->data;

    int nelem   = ggml_nelements(dst);
    int nblocks = (nelem + 2*CUDA_ROPE_BLOCK_SIZE - 1)/(2*CUDA_ROPE_BLOCK_SIZE);

    // Fixed: propagate the `forward` template argument to the kernel — it was
    // hard-coded to <true>, which would silently break a backward cache.
    if (freq_factors) {
        k_rope_cache<forward, true ><<<nblocks, CUDA_ROPE_BLOCK_SIZE, 0, ctx.stream()>>>(nelem, dst->ne[0],
                (float *)dst->data, pos, freq_factors, theta_scale, freq_scale, corr_dims, ext_factor, attn_factor);
    } else {
        k_rope_cache<forward, false><<<nblocks, CUDA_ROPE_BLOCK_SIZE, 0, ctx.stream()>>>(nelem, dst->ne[0],
                (float *)dst->data, pos, freq_factors, theta_scale, freq_scale, corr_dims, ext_factor, attn_factor);
    }
}
|
||||
|
||||
// Public entry point for GGML_OP_ROPE_CACHE (forward direction).
void ggml_cuda_op_rope_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_rope_cache_impl<true>(ctx, dst);
}
|
||||
|
||||
// GGML_OP_ROPE_FAST: apply RoPE using a precomputed cos/sin cache.
// src[0]: F32 activations (ne00 = head dim, ne01 = heads, ne02 = rows/tokens)
// src[1]: F32 rope cache produced by GGML_OP_ROPE_CACHE; its op_params carry
//         the rope configuration (n_dims, mode).
void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == dst->type);

    const int64_t ne00 = src0->ne[0]; // head dims
    const int64_t ne01 = src0->ne[1]; // num heads
    const int64_t ne02 = src0->ne[2]; // num rows (tokens)

    // strides in elements
    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
    const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);

    // the rope parameters travel on the cache tensor (src1), not on dst
    //const int n_past = ((int32_t *) dst->op_params)[0];
    const int n_dims = ((const int32_t *) src1->op_params)[1];
    const int mode   = ((const int32_t *) src1->op_params)[2];

    const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_vision) {
        GGML_ASSERT(n_dims == ne00/2);
    }

    // compute
    // Fixed: the non-neox branches were missing the ne02 argument and passed
    // ggml_nrows() in its place, shifting every subsequent argument by one slot
    // (ne02<-s01, s01<-s02, s02<-n_dims, n_dims<-nr). Only the neox call was
    // correct. Debug printf removed.
    if (is_neox) {
        rope_neox_fast_cuda(
            (const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, stream);
    } else if (is_mrope && !is_vision) {
        rope_multi_fast_cuda(
            (const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, stream);
    } else if (is_vision) {
        rope_vision_fast_cuda(
            (const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, stream);
    } else {
        rope_norm_fast_cuda(
            (const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, stream);
    }
}
|
||||
|
||||
@@ -5,3 +5,7 @@
|
||||
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rope_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -3470,7 +3470,7 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
|
||||
|
||||
auto rope_cache = ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
ggml_set_input(rope_cache);
|
||||
//ggml_set_input(rope_cache);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
Reference in New Issue
Block a user