mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Fuse Q and K RoPE (#980)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -3346,7 +3346,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_soft_cap_max(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
ggml_cuda_op_rope(ctx, dst);
|
||||
if (fusion && i + 2 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
|
||||
cgraph->nodes[i+2]->op == GGML_OP_ROPE &&
|
||||
ggml_cuda_op_rope_rope(ctx, dst, cgraph->nodes[i+2])) {
|
||||
i += 2;
|
||||
}
|
||||
else if (fusion && i + 1 < cgraph->n_nodes &&
|
||||
cgraph->nodes[i+1]->op == GGML_OP_ROPE &&
|
||||
ggml_cuda_op_rope_rope(ctx, dst, cgraph->nodes[i+1])) {
|
||||
i += 1;
|
||||
} else {
|
||||
ggml_cuda_op_rope(ctx, dst);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
ggml_cuda_op_rope_back(ctx, dst);
|
||||
|
||||
@@ -1026,3 +1026,427 @@ bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
ctx.stream());
|
||||
return true;
|
||||
}
|
||||
|
||||
// Fused NeoX-style RoPE applied to two tensors in a single launch (the commit
// title describes the pair as Q and K).
//
// Thread mapping, as read from the index math below:
//   x -> channel index, also used to index `pos`
//   y -> pair index; i0 = 2*y is compared against ne0 and n_dims
//   z -> row index over both inputs: rows [0, ne1_1) use x1/dst1, the
//        remaining rows use x2/dst2
// Only i0 is bounds-checked; the x and z indices are used as-is, so the launch
// grid has to match the tensor extents (see rope_rope_cuda).
//
// s1_*/s2_* are applied as element offsets on the source pointer; the
// destination is indexed densely as (row + channel*ne1)*ne0.
template<bool forward, bool has_ff, typename T>
static __global__ void rope_rope_neox(
        const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2,
        const int s1_1, const int s2_1, const int s1_2, const int s2_2, const int n_dims,
        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
    if (i0 >= ne0) {
        return;
    }

    const int fused_row = blockDim.z*blockIdx.z + threadIdx.z;
    const int channel   = blockDim.x*blockIdx.x + threadIdx.x;

    // pick the tensor this row belongs to and rebase the row index into it
    const bool in_first  = fused_row < ne1_1;
    const T *  src       = in_first ? x1    : x2;
    T *        out       = in_first ? dst1  : dst2;
    const int  row       = in_first ? fused_row : fused_row - ne1_1;
    const int  n_rows    = in_first ? ne1_1 : ne1_2;
    const int  row_step  = in_first ? s1_1  : s1_2;
    const int  chan_step = in_first ? s2_1  : s2_2;

    const int pair     = i0/2;
    const int out_base = (row + channel*n_rows)*ne0 + pair;
    const int src_base = channel*chan_step + row*row_step + pair;

    if (i0 >= n_dims) {
        // beyond the rotated range: copy the two elements at i0, i0 + 1 unchanged
        out[out_base + pair + 0] = src[src_base + pair + 0];
        out[out_base + pair + 1] = src[src_base + pair + 1];
        return;
    }

    const float theta_base  = pos[channel]*powf(theta_scale, i0/2.0f);
    const float freq_factor = has_ff ? freq_factors[pair] : 1.0f;

    float cos_theta;
    float sin_theta;
    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    // the rotated pair is (pair, pair + n_dims/2)
    const float t0 = src[src_base + 0];
    const float t1 = src[src_base + n_dims/2];

    out[out_base + 0]        = t0*cos_theta - t1*sin_theta;
    out[out_base + n_dims/2] = t0*sin_theta + t1*cos_theta;
}
|
||||
|
||||
// Fused "normal" (adjacent-pair) RoPE applied to two tensors in a single launch
// (the commit title describes the pair as Q and K).
//
// Thread mapping, as read from the index math below:
//   x -> channel index, also used to index `pos`
//   y -> pair index; each thread handles elements i0 and i0 + 1
//   z -> row index over both inputs: rows [0, ne1_1) use x1/dst1, the
//        remaining rows use x2/dst2
// Only i0 is bounds-checked; the x and z indices are used as-is, so the launch
// grid has to match the tensor extents (see rope_rope_cuda).
template<bool forward, bool has_ff, typename T>
static __global__ void rope_rope_norm(
        const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2,
        const int s1_1, const int s2_1, const int s1_2, const int s2_2, const int n_dims,
        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
    if (i0 >= ne0) {
        return;
    }

    const int fused_row = blockDim.z*blockIdx.z + threadIdx.z;
    const int channel   = blockDim.x*blockIdx.x + threadIdx.x;

    // pick the tensor this row belongs to and rebase the row index into it
    const bool in_first  = fused_row < ne1_1;
    const T *  src       = in_first ? x1    : x2;
    T *        out       = in_first ? dst1  : dst2;
    const int  row       = in_first ? fused_row : fused_row - ne1_1;
    const int  n_rows    = in_first ? ne1_1 : ne1_2;
    const int  row_step  = in_first ? s1_1  : s1_2;
    const int  chan_step = in_first ? s2_1  : s2_2;

    const int out_base = (row + channel*n_rows)*ne0 + i0;
    const int src_base = channel*chan_step + row*row_step + i0;

    if (i0 >= n_dims) {
        // beyond the rotated range: plain copy of the pair
        out[out_base + 0] = src[src_base + 0];
        out[out_base + 1] = src[src_base + 1];
        return;
    }

    const float theta_base  = pos[channel]*powf(theta_scale, i0/2.0f);
    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;
    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float t0 = src[src_base + 0];
    const float t1 = src[src_base + 1];

    out[out_base + 0] = t0*cos_theta - t1*sin_theta;
    out[out_base + 1] = t0*sin_theta + t1*cos_theta;
}
|
||||
|
||||
// Fused multi-section RoPE applied to two tensors in a single launch (the
// commit title describes the pair as Q and K).
//
// `pos` is read at channel + ne2*k for k in 0..3, i.e. up to four position
// streams spaced ne2 entries apart; which stream a pair uses is derived from
// `sections` and `is_imrope` below. Pairs that match no section keep
// theta_base == 0.
//
// Thread mapping, as read from the index math below:
//   x -> channel index, y -> pair index (i0 = 2*y),
//   z -> row index over both inputs: rows [0, ne1_1) use x1/dst1, the
//        remaining rows use x2/dst2
// Only i0 is bounds-checked; the x and z indices are used as-is, so the launch
// grid has to match the tensor extents (see rope_rope_cuda).
template<bool forward, bool has_ff, typename T>
static __global__ void rope_rope_multi(
        const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2,
        const int ne2, const int s1_1, const int s2_1, const int s1_2, const int s2_2,
        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
    if (i0 >= ne0) {
        return;
    }

    const int fused_row = blockDim.z*blockIdx.z + threadIdx.z;
    const int channel   = blockDim.x*blockIdx.x + threadIdx.x;

    // pick the tensor this row belongs to and rebase the row index into it
    const bool in_first  = fused_row < ne1_1;
    const T *  src       = in_first ? x1    : x2;
    T *        out       = in_first ? dst1  : dst2;
    const int  row       = in_first ? fused_row : fused_row - ne1_1;
    const int  n_rows    = in_first ? ne1_1 : ne1_2;
    const int  row_step  = in_first ? s1_1  : s1_2;
    const int  chan_step = in_first ? s2_1  : s2_2;

    const int pair     = i0/2;
    const int out_base = (row + channel*n_rows)*ne0 + pair;
    const int src_base = channel*chan_step + row*row_step + pair;

    if (i0 >= n_dims) {
        // beyond the rotated range: copy the two elements at i0, i0 + 1 unchanged
        out[out_base + pair + 0] = src[src_base + pair + 0];
        out[out_base + pair + 1] = src[src_base + pair + 1];
        return;
    }

    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
    const int sec_w     = sections.v[1] + sections.v[0];
    const int sector    = (i0 / 2) % sect_dims;

    // index of the position stream to use; -1 means "no section matched"
    int pos_set = -1;
    if (is_imrope) {
        // interleaved layout: the stream is chosen by sector % 3
        const int phase = sector % 3;
        if (phase == 1 && sector < 3 * sections.v[1]) {        // h
            pos_set = 1;
        } else if (phase == 2 && sector < 3 * sections.v[2]) { // w
            pos_set = 2;
        } else if (phase == 0 && sector < 3 * sections.v[0]) { // t
            pos_set = 0;
        }
    } else {
        // contiguous layout: sections follow each other along the sector axis
        if (sector < sections.v[0]) {
            pos_set = 0;
        } else if (sector < sec_w) {
            pos_set = 1;
        } else if (sector < sec_w + sections.v[2]) {
            pos_set = 2;
        } else {
            pos_set = 3;
        }
    }

    const float theta_base  = pos_set < 0 ? 0.0f : pos[channel + ne2 * pos_set]*powf(theta_scale, i0/2.0f);
    const float freq_factor = has_ff ? freq_factors[pair] : 1.0f;

    float cos_theta;
    float sin_theta;
    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    // the rotated pair is (pair, pair + n_dims/2)
    const float t0 = src[src_base + 0];
    const float t1 = src[src_base + n_dims/2];

    out[out_base + 0]        = t0*cos_theta - t1*sin_theta;
    out[out_base + n_dims/2] = t0*sin_theta + t1*cos_theta;
}
|
||||
//rope_rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
// x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, nr, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor,
|
||||
// attn_factor, corr_dims, theta_scale, freq_factors);
|
||||
|
||||
// Fused vision-mode RoPE applied to two tensors in a single launch (the commit
// title describes the pair as Q and K).
//
// Two position streams are read: pos[channel] for sectors below sections.v[0]
// and pos[channel + ne2] for the following sections.v[1] sectors. The rotated
// pair is (i0/2, i0/2 + n_dims); there is no pass-through branch for
// i0 >= n_dims in this kernel (the host side asserts n_dims == ne00/2 for
// vision mode).
//
// Thread mapping, as read from the index math below:
//   x -> channel index, y -> pair index (i0 = 2*y),
//   z -> row index over both inputs: rows [0, ne1_1) use x1/dst1, the
//        remaining rows use x2/dst2
// Only i0 is bounds-checked; the x and z indices are used as-is, so the launch
// grid has to match the tensor extents (see rope_rope_cuda).
template<bool forward, bool has_ff, typename T>
static __global__ void rope_rope_vision(
        const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2,
        const int ne2, const int s1_1, const int s2_1, const int s1_2, const int s2_2, const int n_dims,
        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
        const float theta_scale, const float * freq_factors, const mrope_sections sections) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
    if (i0 >= ne0) {
        return;
    }

    const int fused_row = blockDim.z*blockIdx.z + threadIdx.z;
    const int channel   = blockDim.x*blockIdx.x + threadIdx.x;

    // pick the tensor this row belongs to and rebase the row index into it
    const bool in_first  = fused_row < ne1_1;
    const T *  src       = in_first ? x1    : x2;
    T *        out       = in_first ? dst1  : dst2;
    const int  row       = in_first ? fused_row : fused_row - ne1_1;
    const int  n_rows    = in_first ? ne1_1 : ne1_2;
    const int  row_step  = in_first ? s1_1  : s1_2;
    const int  chan_step = in_first ? s2_1  : s2_2;

    const int pair     = i0/2;
    const int out_base = (row + channel*n_rows)*ne0 + pair;
    const int src_base = channel*chan_step + row*row_step + pair;

    const int sect_dims = sections.v[0] + sections.v[1];
    const int sec_w     = sections.v[1] + sections.v[0];
    const int sector    = (i0 / 2) % sect_dims;

    float theta_base = 0.0f;
    if (sector < sections.v[0]) {
        // first section: exponent is the sector index itself
        const int p = sector;
        theta_base = pos[channel]*powf(theta_scale, p);
    } else if (sector < sec_w) {
        // second section: exponent restarts from zero, second position stream
        const int p = sector - sections.v[0];
        theta_base = pos[channel + ne2]*powf(theta_scale, p);
    }

    const float freq_factor = has_ff ? freq_factors[pair] : 1.0f;

    float cos_theta;
    float sin_theta;
    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float t0 = src[src_base + 0];
    const float t1 = src[src_base + n_dims];

    out[out_base + 0]      = t0*cos_theta - t1*sin_theta;
    out[out_base + n_dims] = t0*sin_theta + t1*cos_theta;
}
|
||||
|
||||
// Launches the fused RoPE kernel selected by `kernel` with `has_ff` fixed at
// compile time. Kernel ids: 0 = neox, 1 = multi, 2 = vision, 3 = norm; any
// other value aborts. `nr` is forwarded as the `ne2` argument of the multi and
// vision kernels (the spacing between position streams in `pos`).
template<bool forward, bool has_ff, typename T>
static void rope_rope_launch(int kernel, const dim3 block_nums, const dim3 block_dims,
        const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2,
        const int s1_1, const int s2_1, const int se1_2, const int se2_2, const int n_dims, const int nr,
        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
        const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
    switch (kernel) {
        case 0:
            rope_rope_neox<forward, has_ff, T><<<block_nums, block_dims, 0, stream>>>(
                    x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor,
                    attn_factor, corr_dims, theta_scale, freq_factors);
            break;
        case 1:
            rope_rope_multi<forward, has_ff, T><<<block_nums, block_dims, 0, stream>>>(
                    x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, nr, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor,
                    attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
            break;
        case 2:
            rope_rope_vision<forward, has_ff, T><<<block_nums, block_dims, 0, stream>>>(
                    x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, nr, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor,
                    attn_factor, corr_dims, theta_scale, freq_factors, sections);
            break;
        case 3:
            rope_rope_norm<forward, has_ff, T><<<block_nums, block_dims, 0, stream>>>(
                    x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor,
                    attn_factor, corr_dims, theta_scale, freq_factors);
            break;
        default: GGML_ABORT("fatal error");
    }
}

// Host-side launcher for the fused two-tensor RoPE kernels.
//
//   kernel        0 = neox, 1 = multi, 2 = vision, 3 = norm
//   x1/x2         source pointers, dst1/dst2 the matching destinations
//   ne0           row length (must be even); ne1_1/ne1_2 rows per channel of each tensor
//   s1_*/s2_*     source strides, applied by the kernels as element offsets
//   nr            grid extent in x (one block per channel; `pos` is indexed by it)
//   freq_factors  optional; nullptr selects the has_ff == false instantiation
//   is_imrope     forwarded to rope_rope_multi only. This parameter used to be
//                 called `is_mrope`, but the caller passes its is_imrope flag
//                 and the kernel parameter is is_imrope, so the name now matches.
//
// Launch shape: blocks of (1, CUDA_ROPE_BLOCK_SIZE, 1) threads on a grid of
// (nr, ceil(ne0 / (2*CUDA_ROPE_BLOCK_SIZE)), ne1_1 + ne1_2), so the z index
// enumerates the rows of both tensors back to back. The launch is asynchronous
// on `stream`; no error check is performed here.
template<bool forward, typename T>
static void rope_rope_cuda(int kernel,
        const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2,
        const int s1_1, const int s2_1, const int se1_2, const int se2_2, const int n_dims, const int nr,
        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nr, n_blocks_x, ne1_1 + ne1_2);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    // has_ff is a template parameter of the kernels, so branch once here and
    // let the helper share the per-kernel dispatch between both cases
    if (freq_factors == nullptr) {
        rope_rope_launch<forward, false, T>(kernel, block_nums, block_dims,
                x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, s1_1, s2_1, se1_2, se2_2, n_dims, nr,
                pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
                sections, is_imrope, stream);
    } else {
        rope_rope_launch<forward, true, T>(kernel, block_nums, block_dims,
                x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, s1_1, s2_1, se1_2, se2_2, n_dims, nr,
                pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
                sections, is_imrope, stream);
    }
}
|
||||
|
||||
// Tries to execute two ROPE nodes (dst1, dst2) with one fused kernel launch.
//
// Returns false, without launching anything, when the pair does not qualify:
// different position or frequency-factor tensors, a 4th dimension larger than
// one, mismatching ne[0]/ne[2]/ne[3], differing op_params, mismatching types,
// or a type other than F32/F16. The caller is then expected to run the nodes
// separately. Returns true after the fused launch has been enqueued on
// ctx.stream().
template <bool forward>
bool ggml_cuda_op_rope_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) {
    // both nodes must read the same positions (src[1]) and frequency factors (src[2])
    if (dst1->src[1] != dst2->src[1]) return false;
    if (dst1->src[2] != dst2->src[2]) return false;
    // same row length and channel count; only ne[1] may differ between the two
    if (dst1->ne[3] > 1 || dst1->ne[0] != dst2->ne[0] || dst1->ne[2] != dst2->ne[2] || dst1->ne[3] != dst2->ne[3]) return false;
    // the 15 int-sized op_params slots read below must be identical for both nodes
    if (memcmp(dst1->op_params, dst2->op_params, 15*sizeof(int))) return false;
    if (dst1->src[0]->type != dst2->src[0]->type) return false;
    if (dst1->type != dst2->type) return false;
    if (dst1->src[0]->type != GGML_TYPE_F32 && dst1->src[0]->type != GGML_TYPE_F16) return false;
    if (dst1->src[0]->type != dst1->type) return false;

    const int64_t ne00   = dst1->src[0]->ne[0];
    const int64_t ne01_1 = dst1->src[0]->ne[1];
    const int64_t ne01_2 = dst2->src[0]->ne[1];
    const int64_t ne02   = dst1->src[0]->ne[2];

    // byte strides converted to element strides, as the kernels index typed pointers
    const size_t s01_1 = dst1->src[0]->nb[1] / ggml_type_size(dst1->src[0]->type);
    const size_t s02_1 = dst1->src[0]->nb[2] / ggml_type_size(dst1->src[0]->type);
    const size_t s01_2 = dst2->src[0]->nb[1] / ggml_type_size(dst2->src[0]->type);
    const size_t s02_2 = dst2->src[0]->nb[2] / ggml_type_size(dst2->src[0]->type);

    // op_params layout used here: [1] n_dims, [2] mode, [4] n_ctx_orig,
    // [5..10] float parameters, [11..14] mrope sections. Slots [0] and [3] are not read.
    const int n_dims     = ((int32_t *) dst1->op_params)[1];
    const int mode       = ((int32_t *) dst1->op_params)[2];
    const int n_ctx_orig = ((int32_t *) dst1->op_params)[4];
    mrope_sections sections;

    // RoPE alteration for extended context
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;

    memcpy(&freq_base,   (int32_t *) dst1->op_params +  5, sizeof(float));
    memcpy(&freq_scale,  (int32_t *) dst1->op_params +  6, sizeof(float));
    memcpy(&ext_factor,  (int32_t *) dst1->op_params +  7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst1->op_params +  8, sizeof(float));
    memcpy(&beta_fast,   (int32_t *) dst1->op_params +  9, sizeof(float));
    memcpy(&beta_slow,   (int32_t *) dst1->op_params + 10, sizeof(float));
    // fixed: this read "§ions.v" because "&sect" had been decoded as an HTML entity
    memcpy(&sections.v,  (int32_t *) dst1->op_params + 11, sizeof(int)*4);

    const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_mrope) {
        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
    }

    if (is_vision) {
        GGML_ASSERT(n_dims == ne00/2);
    }

    const int32_t * pos = (const int32_t *) dst1->src[1]->data;

    const float * freq_factors = nullptr;
    if (dst1->src[2] != nullptr) {
        freq_factors = (const float *) dst1->src[2]->data;
    }

    rope_corr_dims corr_dims;
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

    auto type = dst1->type;
    // kernel id understood by rope_rope_cuda: 0 = neox, 1 = multi, 2 = vision, 3 = norm
    int kernel = is_neox ? 0 : is_mrope && !is_vision ? 1 : is_vision ? 2 : 3;

    // compute
    if (type == GGML_TYPE_F32) {
        rope_rope_cuda<forward>(kernel,
                (const float *)dst1->src[0]->data, (const float *)dst2->src[0]->data, (float *)dst1->data, (float *)dst2->data,
                ne00, ne01_1, ne01_2, s01_1, s02_1, s01_2, s02_2, n_dims, ne02, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, ctx.stream());
    } else {
        // only F16 can reach this branch because of the type check above
        rope_rope_cuda<forward>(kernel,
                (const half *)dst1->src[0]->data, (const half *)dst2->src[0]->data, (half *)dst1->data, (half *)dst2->data,
                ne00, ne01_1, ne01_2, s01_1, s02_1, s01_2, s02_2, n_dims, ne02, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, ctx.stream());
    }
    return true;
}
|
||||
|
||||
// Public entry point for the fused two-node RoPE: forwards to the
// implementation with the `forward` template flag set to true and passes its
// result through (false means nothing was launched and the pair was not fused).
bool ggml_cuda_op_rope_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) {
    const bool fused = ggml_cuda_op_rope_rope_impl</*forward=*/true>(ctx, dst1, dst2);
    return fused;
}
|
||||
|
||||
@@ -13,3 +13,5 @@ void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);
|
||||
|
||||
bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);
|
||||
|
||||
// Attempts to run the two ROPE nodes dst1 and dst2 with one fused launch.
// Returns false if the pair cannot be fused; the caller must then compute them separately.
bool ggml_cuda_op_rope_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);
|
||||
|
||||
Reference in New Issue
Block a user