CUDA: fix TG with SER

This commit is contained in:
Iwan Kawrakow
2025-05-10 11:06:48 +03:00
parent b38112028e
commit c4e1c2c905

View File

@@ -147,11 +147,27 @@ static __global__ void mul_mat_vec_q(
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst,
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) {
int i2 = blockIdx.y;
char * cdst = (char *)dst + i2*nb2;
int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
if (i02 < 0) return;
if (i02 < 0) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
constexpr int rows_per_cuda_block = 1;
#else
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
const int row0 = rows_per_cuda_block*blockIdx.x;
if (threadIdx.y == 0) {
dst = (float *)cdst;
for (int j = 0; j < ncols_y; ++j) {
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = 0;
}
}
}
return;
}
const char * cx = (const char *)vx + i02*nb02;
const char * cy = (const char *)vy + i2*nb12;
char * cdst = (char *)dst + i2*nb2;
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
}