CUDA non-contiguous RoPE (#66)

In this way we can avoid the Q, K, V copies being made
after multiplication with the QKV tensor in, e.g., Phi-3.5-mini.
This results in a 6-7% speedup of PP-512(Phi-3.5-mini)
on CUDA (RTX-4080)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2024-09-28 17:41:21 +03:00
committed by GitHub
parent 947a348990
commit 54b1c97878
3 changed files with 231 additions and 32 deletions

View File

@@ -2924,9 +2924,10 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_CAP_MAX:
return true;
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);
return true;
//case GGML_OP_ROPE:
// return ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:

View File

@@ -68,6 +68,49 @@ static __global__ void rope_norm(
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_ff>
static __global__ void rope_norm_nc(
const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int j2 = row/ne1;
const int j1 = row%ne1;
const T * xx = x + j1*nb1 + j2*nb2;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;
dst[i + 0] = xx[i0 + 0];
dst[i + 1] = xx[i0 + 1];
return;
}
const int i = row*ne0 + i0;
const int i2 = row/p_delta_rows;
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = xx[i0 + 0];
const float x1 = xx[i0 + 1];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_ff>
static __global__ void rope_neox(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
@@ -108,6 +151,49 @@ static __global__ void rope_neox(
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_ff>
static __global__ void rope_neox_nc(
const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int j2 = row/ne1;
const int j1 = row%ne1;
const T * xx = x + j1*nb1 + j2*nb2;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;
dst[i + 0] = xx[i0 + 0];
dst[i + 1] = xx[i0 + 1];
return;
}
const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows;
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = xx[i0/2 + 0];
const float x1 = xx[i0/2 + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
template<typename T>
static void rope_norm_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
@@ -132,6 +218,30 @@ static void rope_norm_cuda(
}
}
template<typename T>
static void rope_norm_nc_cuda(
const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_norm_nc<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
} else {
rope_norm_nc<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
}
}
template<typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
@@ -156,6 +266,30 @@ static void rope_neox_cuda(
}
}
template<typename T>
static void rope_neox_nc_cuda(
const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_neox_nc<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
} else {
rope_neox_nc<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
}
}
static void rope_norm_cuda_f16(
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
@@ -170,6 +304,20 @@ static void rope_norm_cuda_f32(
rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_norm_cuda_nc_f16(
const half * x, half * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_norm_nc_cuda<half>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_norm_cuda_nc_f32(
const float * x, float * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_norm_nc_cuda<float>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_f16(
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
@@ -185,6 +333,20 @@ static void rope_neox_cuda_f32(
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_nc_f16(
const half * x, half * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_neox_nc_cuda<half>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_nc_f32(
const float * x, float * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_neox_nc_cuda<float>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -196,7 +358,8 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
const bool is_contiguous = ggml_is_contiguous(src0);
//GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
@@ -239,33 +402,69 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
// compute
if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
if (is_contiguous) {
if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ABORT("fatal error");
}
} else {
GGML_ABORT("fatal error");
if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_norm_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ABORT("fatal error");
}
}
} else {
if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_norm_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_nc_f32(
(const float *)src0_d, (float *)dst_d, ne00, ne01, src0->nb[1]/sizeof(float), src0->nb[2]/sizeof(float),
n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_nc_f16(
(const half *)src0_d, (half *)dst_d, ne00, ne01, src0->nb[1]/sizeof(half), src0->nb[2]/sizeof(half),
n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ABORT("fatal error");
}
} else {
GGML_ABORT("fatal error");
if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda_nc_f32(
(const float *)src0_d, (float *)dst_d, ne00, ne01, src0->nb[1]/sizeof(float), src0->nb[2]/sizeof(float),
n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_norm_cuda_nc_f16(
(const half *)src0_d, (half *)dst_d, ne00, ne01, src0->nb[1]/sizeof(half), src0->nb[2]/sizeof(half),
n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ABORT("fatal error");
}
}
}
}

View File

@@ -10919,23 +10919,22 @@ struct llm_build_context {
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
cb(cur, "wqkv", il);
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd));
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd));
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
}
else {
Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow