diff --git a/ggml/src/ggml-cuda/iqk_mmvq_templates.cuh b/ggml/src/ggml-cuda/iqk_mmvq_templates.cuh index 466b1e9c..95a80d54 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq_templates.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq_templates.cuh @@ -104,7 +104,7 @@ __device__ void iqk_mul_mat_vec_q_kerne( } if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { - dst[j*nrows_dst + row0 + threadIdx.x] = bias ? tmp[j][threadIdx.x] + bias[j*nrows_dst + row0 + threadIdx.x] : tmp[j][threadIdx.x]; + dst[j*nrows_dst + row0 + threadIdx.x] = bias ? tmp[j][threadIdx.x] + bias[row0 + threadIdx.x] : tmp[j][threadIdx.x]; } } } @@ -211,8 +211,8 @@ __device__ void iqk_fused_mul_mat_vec_q_kernel( default: { constexpr float alpha = 1.702f; constexpr float limit = 7.0f; - g += bias_g[j*nrows_dst + row0 + threadIdx.x]; - u += bias_u[j*nrows_dst + row0 + threadIdx.x]; + g += bias_g[row0 + threadIdx.x]; + u += bias_u[row0 + threadIdx.x]; g = fminf(g, limit); u = fmaxf(fminf(u, limit), -limit); r = g / (1.0f + expf(-g * alpha)) * (1.0f + u);