mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 02:41:47 +00:00
Fusing also for iqk/trellis/repacked quants
This commit is contained in:
@@ -108,8 +108,9 @@ __device__ void iqk_mul_mat_vec_q(
|
||||
}
|
||||
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
|
||||
__device__ void iqk_fused_mul_mat_vec_q(
|
||||
__device__ void iqk_fused_mul_mat_vec_q_kernel(
|
||||
const void * __restrict__ vup, const void * __restrict__ vgate, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const float * __restrict__ bias_u, const float * __restrict__ bias_g,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
|
||||
ggml_unary_op unary_op) {
|
||||
|
||||
@@ -156,7 +157,7 @@ __device__ void iqk_fused_mul_mat_vec_q(
|
||||
vec_dot_q_cuda((const void *)((const char *)vup + row0*row_size),
|
||||
&y[j*blocks_per_col_y + kby], kbx, kqs, tmp_u[j]);
|
||||
vec_dot_q_cuda((const void *)((const char *)vgate + row0*row_size),
|
||||
&y[j*blocks_per_col_y + kby], kbx, kqs, tmp_u[j]);
|
||||
&y[j*blocks_per_col_y + kby], kbx, kqs, tmp_g[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -199,12 +200,21 @@ __device__ void iqk_fused_mul_mat_vec_q(
|
||||
switch (unary_op) {
|
||||
case GGML_UNARY_OP_SILU: r = u*g/(1 + expf(-g)); break;
|
||||
case GGML_UNARY_OP_RELU: r = fmaxf(g, 0.0f) * u; break;
|
||||
// we assume that the supported ops have been checked by the caller
|
||||
default: {
|
||||
case GGML_UNARY_OP_GELU: {
|
||||
constexpr float GELU_COEF_A = 0.044715f;
|
||||
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
r = 0.5f*g*u*(1.0f + tanhf(SQRT_2_OVER_PI*g*(1.0f + GELU_COEF_A*g*g)));
|
||||
} break;
|
||||
// we assume that the supported ops have been checked by the caller
|
||||
default: {
|
||||
constexpr float alpha = 1.702f;
|
||||
constexpr float limit = 7.0f;
|
||||
g += bias_g[j*nrows_dst + row0 + threadIdx.x];
|
||||
u += bias_u[j*nrows_dst + row0 + threadIdx.x];
|
||||
g = fminf(g, limit);
|
||||
u = fmaxf(fminf(u, limit), -limit);
|
||||
r = g / (1.0f + expf(-g * alpha)) * (1.0f + u);
|
||||
} break;
|
||||
}
|
||||
dst[j*nrows_dst + row0 + threadIdx.x] = r;
|
||||
}
|
||||
@@ -229,6 +239,31 @@ __global__ void iqk_mul_mat_vec_q(
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, ncols_y, n_interleaved>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
|
||||
}
|
||||
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__global__ void iqk_fused_mul_mat_vec_q(
|
||||
const void * __restrict__ vx_u, const void * __restrict__ vx_g, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const char * __restrict__ ids_data, const void * __restrict__ bias_u, const void * __restrict__ bias_g, const uint64_t bias_nb1,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
|
||||
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) {
|
||||
|
||||
int i2 = blockIdx.y;
|
||||
int i02 = *(const int *)(ids_data + i2*ids_nb0);
|
||||
if (i02 < 0) return;
|
||||
const char * cx_u = (const char *)vx_u + i02*nb02;
|
||||
const char * cx_g = (const char *)vx_g + i02*nb02;
|
||||
const char * cy = (const char *)vy + i2*nb12;
|
||||
const float * cx_u_b = bias_u ? (const float *)((const char *)bias_u + i02*bias_nb1) : nullptr;
|
||||
const float * cx_g_b = bias_g ? (const float *)((const char *)bias_g + i02*bias_nb1) : nullptr;
|
||||
char * cdst = (char *)dst + i2*nb2;
|
||||
iqk_fused_mul_mat_vec_q_kernel<type, vdr, vec_dot_q_cuda, ncols_y, n_interleaved>(
|
||||
cx_u, cx_g, cy, (float *)cdst, cx_u_b, cx_g_b,
|
||||
ncols_x, nrows_x, nrows_y, nrows_dst, row_size, unary_op);
|
||||
}
|
||||
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int n_interleaved = 1>
|
||||
void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
|
||||
@@ -270,6 +305,74 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
|
||||
const int64_t row_size = ggml_row_size(type, args.ncols_x);
|
||||
|
||||
//const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
|
||||
//const void * __restrict__ bias_u, const void * __restrict__ bias_g, const uint64_t bias_nb1,
|
||||
//const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
|
||||
//const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) {
|
||||
|
||||
if (args.vx_u && args.vx_g && args.ids_data && args.unary_op != GGML_UNARY_OP_COUNT) {
|
||||
switch (args.ncols_y) {
|
||||
case 1:
|
||||
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
|
||||
args.vx_u, args.vx_g, args.vy, args.dst,
|
||||
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
|
||||
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 2:
|
||||
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
|
||||
args.vx_u, args.vx_g, args.vy, args.dst,
|
||||
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
|
||||
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 3:
|
||||
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
|
||||
args.vx_u, args.vx_g, args.vy, args.dst,
|
||||
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
|
||||
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 4:
|
||||
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
|
||||
args.vx_u, args.vx_g, args.vy, args.dst,
|
||||
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
|
||||
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 5:
|
||||
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
|
||||
args.vx_u, args.vx_g, args.vy, args.dst,
|
||||
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
|
||||
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 6:
|
||||
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
|
||||
args.vx_u, args.vx_g, args.vy, args.dst,
|
||||
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
|
||||
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 7:
|
||||
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
|
||||
args.vx_u, args.vx_g, args.vy, args.dst,
|
||||
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
|
||||
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 8:
|
||||
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
|
||||
args.vx_u, args.vx_g, args.vy, args.dst,
|
||||
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
|
||||
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
switch (args.ncols_y) {
|
||||
case 1:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
@@ -299,6 +402,7 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,
|
||||
|
||||
@@ -231,12 +231,12 @@ static __device__ void fused_mul_mat_vec_q(
|
||||
switch (unary_op) {
|
||||
case GGML_UNARY_OP_SILU: r = u*g/(1 + expf(-g)); break;
|
||||
case GGML_UNARY_OP_RELU: r = fmaxf(g, 0.0f) * u; break;
|
||||
// we assume that the supported ops have been checked by the caller
|
||||
case GGML_UNARY_OP_GELU: {
|
||||
constexpr float GELU_COEF_A = 0.044715f;
|
||||
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
r = 0.5f*g*u*(1.0f + tanhf(SQRT_2_OVER_PI*g*(1.0f + GELU_COEF_A*g*g)));
|
||||
} break;
|
||||
// we assume that the supported ops have been checked by the caller
|
||||
default: {
|
||||
constexpr float alpha = 1.702f;
|
||||
constexpr float limit = 7.0f;
|
||||
|
||||
Reference in New Issue
Block a user