Fusing also for iqk/trellis/repacked quants

This commit is contained in:
Iwan Kawrakow
2025-10-24 15:05:18 +03:00
parent 3da71dcda2
commit 196e73588c
2 changed files with 109 additions and 5 deletions

View File

@@ -108,8 +108,9 @@ __device__ void iqk_mul_mat_vec_q(
}
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
__device__ void iqk_fused_mul_mat_vec_q(
__device__ void iqk_fused_mul_mat_vec_q_kernel(
const void * __restrict__ vup, const void * __restrict__ vgate, const void * __restrict__ vy, float * __restrict__ dst,
const float * __restrict__ bias_u, const float * __restrict__ bias_g,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
ggml_unary_op unary_op) {
@@ -156,7 +157,7 @@ __device__ void iqk_fused_mul_mat_vec_q(
vec_dot_q_cuda((const void *)((const char *)vup + row0*row_size),
&y[j*blocks_per_col_y + kby], kbx, kqs, tmp_u[j]);
vec_dot_q_cuda((const void *)((const char *)vgate + row0*row_size),
&y[j*blocks_per_col_y + kby], kbx, kqs, tmp_u[j]);
&y[j*blocks_per_col_y + kby], kbx, kqs, tmp_g[j]);
}
}
}
@@ -199,12 +200,21 @@ __device__ void iqk_fused_mul_mat_vec_q(
switch (unary_op) {
case GGML_UNARY_OP_SILU: r = u*g/(1 + expf(-g)); break;
case GGML_UNARY_OP_RELU: r = fmaxf(g, 0.0f) * u; break;
// we assume that the supported ops have been checked by the caller
default: {
case GGML_UNARY_OP_GELU: {
constexpr float GELU_COEF_A = 0.044715f;
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
r = 0.5f*g*u*(1.0f + tanhf(SQRT_2_OVER_PI*g*(1.0f + GELU_COEF_A*g*g)));
} break;
// we assume that the supported ops have been checked by the caller
default: {
constexpr float alpha = 1.702f;
constexpr float limit = 7.0f;
g += bias_g[j*nrows_dst + row0 + threadIdx.x];
u += bias_u[j*nrows_dst + row0 + threadIdx.x];
g = fminf(g, limit);
u = fmaxf(fminf(u, limit), -limit);
r = g / (1.0f + expf(-g * alpha)) * (1.0f + u);
} break;
}
dst[j*nrows_dst + row0 + threadIdx.x] = r;
}
@@ -229,6 +239,31 @@ __global__ void iqk_mul_mat_vec_q(
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, ncols_y, n_interleaved>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
}
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// Fused up/gate mat-vec entry point. Each blockIdx.y slot resolves an id from
// ids_data, offsets the up/gate/activation/bias/dst pointers by the per-slot
// byte strides (nb02, nb12, nb2, bias_nb1), and delegates the dot products and
// the unary_op gating math to iqk_fused_mul_mat_vec_q_kernel.
// bias_u/bias_g may be null; they are only read on the default unary_op path
// of the device kernel — presumably the caller guarantees they are non-null
// for that op (see "checked by the caller" note there). TODO confirm.
__global__ void iqk_fused_mul_mat_vec_q(
    const void * __restrict__ vx_u, const void * __restrict__ vx_g, const void * __restrict__ vy, float * __restrict__ dst,
    const char * __restrict__ ids_data, const void * __restrict__ bias_u, const void * __restrict__ bias_g, const uint64_t bias_nb1,
    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
    const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) {
int i2 = blockIdx.y;                             // slot index along the second grid dimension
int i02 = *(const int *)(ids_data + i2*ids_nb0); // id selected for this slot (ids_nb0 = byte stride)
if (i02 < 0) return;                             // negative id => nothing to compute for this slot
// Byte-offset the up/gate weight matrices by the selected id, and the
// activations/destination by the slot index. nb* are byte strides, hence
// the char* arithmetic.
const char * cx_u = (const char *)vx_u + i02*nb02;
const char * cx_g = (const char *)vx_g + i02*nb02;
const char * cy  = (const char *)vy + i2*nb12;
// Optional per-id bias rows; stay nullptr when no bias tensor was provided.
const float * cx_u_b = bias_u ? (const float *)((const char *)bias_u + i02*bias_nb1) : nullptr;
const float * cx_g_b = bias_g ? (const float *)((const char *)bias_g + i02*bias_nb1) : nullptr;
char * cdst = (char *)dst + i2*nb2;
iqk_fused_mul_mat_vec_q_kernel<type, vdr, vec_dot_q_cuda, ncols_y, n_interleaved>(
        cx_u, cx_g, cy, (float *)cdst, cx_u_b, cx_g_b,
        ncols_x, nrows_x, nrows_y, nrows_dst, row_size, unary_op);
}
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int n_interleaved = 1>
void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
@@ -270,6 +305,74 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
const int64_t row_size = ggml_row_size(type, args.ncols_x);
//const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
//const void * __restrict__ bias_u, const void * __restrict__ bias_g, const uint64_t bias_nb1,
//const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
//const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) {
if (args.vx_u && args.vx_g && args.ids_data && args.unary_op != GGML_UNARY_OP_COUNT) {
switch (args.ncols_y) {
case 1:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
break;
case 2:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
break;
case 3:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
break;
case 4:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
break;
case 5:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
break;
case 6:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
break;
case 7:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
break;
case 8:
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vx_g, args.vy, args.dst,
args.ids_data, args.bias_u, args.bias_g, args.bias_nb1,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size,
args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
break;
default:
GGML_ASSERT(false);
break;
}
} else {
switch (args.ncols_y) {
case 1:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
@@ -299,6 +402,7 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
GGML_ASSERT(false);
break;
}
}
}
__device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,

View File

@@ -231,12 +231,12 @@ static __device__ void fused_mul_mat_vec_q(
switch (unary_op) {
case GGML_UNARY_OP_SILU: r = u*g/(1 + expf(-g)); break;
case GGML_UNARY_OP_RELU: r = fmaxf(g, 0.0f) * u; break;
// we assume that the supported ops have been checked by the caller
case GGML_UNARY_OP_GELU: {
constexpr float GELU_COEF_A = 0.044715f;
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
r = 0.5f*g*u*(1.0f + tanhf(SQRT_2_OVER_PI*g*(1.0f + GELU_COEF_A*g*g)));
} break;
// we assume that the supported ops have been checked by the caller
default: {
constexpr float alpha = 1.702f;
constexpr float limit = 7.0f;