Biased mmvq: minor optimization (#880)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-10-31 14:21:18 +02:00
committed by GitHub
parent a3bd0158f7
commit cfb840379f
2 changed files with 46 additions and 7 deletions

View File

@@ -112,6 +112,10 @@ static __device__ void mul_mat_vec_q(
}
}
float local_bias[rows_per_cuda_block] = { 0.0f };
if (bias && threadIdx.y == 0 && threadIdx.x < rows_per_cuda_block && row0 + threadIdx.x < nrows_dst) {
local_bias[threadIdx.x] = bias[row0 + threadIdx.x];
}
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
if (threadIdx.y > 0) {
#pragma unroll
@@ -140,7 +144,7 @@ static __device__ void mul_mat_vec_q(
}
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = bias ? tmp[j][threadIdx.x] + bias[j*nrows_dst + row0 + threadIdx.x] : tmp[j][threadIdx.x];
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x] + local_bias[threadIdx.x];
}
}
}
@@ -176,6 +180,14 @@ static __device__ void fused_mul_mat_vec_q(
// partial sum for each thread
float tmp_u[ncols_y][rows_per_cuda_block] = {0.0f};
float tmp_g[ncols_y][rows_per_cuda_block] = {0.0f};
float local_bias_u[rows_per_cuda_block] = { 0.0f };
float local_bias_g[rows_per_cuda_block] = { 0.0f };
if (bias_u && threadIdx.y == 0 && threadIdx.x < rows_per_cuda_block && row0 + threadIdx.x < nrows_dst) {
local_bias_u[threadIdx.x] = bias_u[row0 + threadIdx.x];
}
if (bias_g && threadIdx.y == 0 && threadIdx.x < rows_per_cuda_block && row0 + threadIdx.x < nrows_dst) {
local_bias_g[threadIdx.x] = bias_g[row0 + threadIdx.x];
}
const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -242,8 +254,8 @@ static __device__ void fused_mul_mat_vec_q(
default: {
constexpr float alpha = 1.702f;
constexpr float limit = 7.0f;
g += bias_g[j*nrows_dst + row0 + threadIdx.x];
u += bias_u[j*nrows_dst + row0 + threadIdx.x];
g += local_bias_g[threadIdx.x];
u += local_bias_u[threadIdx.x];
g = fminf(g, limit);
u = fmaxf(fminf(u, limit), -limit);
r = g / (1.0f + expf(-g * alpha)) * (1.0f + u);