Minor CUDA PP speed improvement (#567)

* Slightly better q8_0_q8_1 kernel and iq4_ks tile loading

* Minor
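
For illustration only, a toy host-side sketch (plain C++, names and sizes made up, not the actual MMA kernel) of the loop restructuring in vec_dot_q8_0_q8_1_mma: instead of preloading every K-slice of the A tile into registers up front, each K-slice is loaded once per k-iteration and reused across the column tiles, which lowers register pressure:

```cpp
#include <cstdio>

// Toy tile loop, illustrative only: "A" plays the role of the quantized
// x-tile, "B" the activation tile; ROWS/COLS/K_SLICES are made-up sizes.
int main() {
    constexpr int K_SLICES = 4, ROWS = 2, COLS = 3;
    float A[ROWS][K_SLICES], B[COLS][K_SLICES], sum[ROWS][COLS] = {};

    for (int r = 0; r < ROWS; ++r)
        for (int k = 0; k < K_SLICES; ++k) A[r][k] = r + 0.1f*k;
    for (int c = 0; c < COLS; ++c)
        for (int k = 0; k < K_SLICES; ++k) B[c][k] = c - 0.1f*k;

    // Restructured order: the K-slice is the outer loop, so only one slice
    // of A is live at a time instead of all K_SLICES of them.
    for (int k = 0; k < K_SLICES; ++k) {
        for (int r = 0; r < ROWS; ++r) {
            const float a = A[r][k];           // load the slice once ...
            for (int c = 0; c < COLS; ++c) {
                sum[r][c] += a * B[c][k];      // ... reuse it for every column tile
            }
        }
    }

    printf("sum[0][0] = %g\n", sum[0][0]);
    return 0;
}
```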

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-07-02 09:11:33 +02:00
Committed by: GitHub
Parent: 8a71405f5f
Commit: 6215d9315c


@@ -861,61 +861,60 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const float * y_df = (const float *) y;
const half2 * y_ds = (const half2 *) y;
- mma_A A[ntx][WARP_SIZE/QI8_0];
- float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+ mma_A A[ntx];
+ float dA[ntx][mma_C::ne/2];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
#pragma unroll
- for (int n = 0; n < ntx; ++n) {
- #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
- const int k0 = k00 + k01;
- A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
- }
- #pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
- #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
- const int k0 = k00 + k01;
- dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
- }
- }
- }
- #pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
- #pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
- mma_B B;
- float dB[mma_C::ne/2];
- B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
- #pragma unroll
- for (int l = 0; l < mma_C::ne/2; ++l) {
- const int j = j0 + mma_C::get_j(l);
- if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
- dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
- } else {
- dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
- }
- }
- #pragma unroll
- for (int n = 0; n < ntx; ++n) {
- mma_C C;
- C.mma_K8(A[n][k01/QI8_0], B);
- #pragma unroll
- for (int l = 0; l < mma_C::ne; ++l) {
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ const int k0 = k00 + k01;
+ mma_B B;
+ float dB[mma_C::ne/2];
+ B.load(y_qs + k01, MMQ_TILE_Y_K);
+ #pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = mma_C::get_j(l);
+ if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ } else {
+ dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ #pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+ #pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+ }
+ mma_C C;
+ C.mma_K8(A[n], B);
+ #pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+ }
+ }
+ #pragma unroll
+ for (int j0 = ntx*mma_C::J; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+ #pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+ if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ } else {
+ dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ #pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n], B);
+ #pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
}
}
}
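
The load_tiles_iq4_ks hunks below lean on the same 16-entry lookup idea as get_int_from_table_16: one 32-bit word carries eight 4-bit indices, and the low/high nibbles of its four bytes are mapped through a small table of signed values. A minimal host-side sketch of that expansion (plain C++, IQ4_NL-style values used as an example table only; the repo's iq4k_values table is selected by a scale bit):

```cpp
#include <cstdint>
#include <cstdio>

// Host-side analogue of the nibble-to-table expansion (illustrative, not the
// CUDA helper itself): low nibbles of the four bytes go to lo[], high nibbles to hi[].
static void unpack_nibbles_16(uint32_t q4, const int8_t * values, int8_t lo[4], int8_t hi[4]) {
    for (int i = 0; i < 4; ++i) {
        const uint8_t b = (q4 >> (8*i)) & 0xFF;
        lo[i] = values[b & 0x0F];  // low  nibble of byte i
        hi[i] = values[b >> 4];    // high nibble of byte i
    }
}

int main() {
    // Example non-uniform 4-bit value table (IQ4_NL-style), for illustration only.
    const int8_t values[16] = {-127, -104, -83, -65, -49, -35, -22, -10,
                                  1,   13,  25,  38,  53,  69,  89, 113};
    int8_t lo[4], hi[4];
    unpack_nibbles_16(0x80F31A05u, values, lo, hi);
    printf("lo: %d %d %d %d  hi: %d %d %d %d\n",
           lo[0], lo[1], lo[2], lo[3], hi[0], hi[1], hi[2], hi[3]);
    return 0;
}
```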
@@ -2701,6 +2700,64 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
}
+ //template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
+ // const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+ //
+ //#ifdef INT8_MMA_AVAILABLE
+ // int * x_qs = (int *) x_tile;
+ // float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ //#else
+ // constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ // int * x_qs = (int *) x_tile;
+ // float * x_df = (float *) (x_qs + txs.qs);
+ //#endif // INT8_MMA_AVAILABLE
+ //
+ // const int kbx = 0; // threadIdx.x / QI4_XS
+ // const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+ //
+ //#pragma unroll
+ // for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ // int i = i0 + threadIdx.y;
+ //
+ // if (need_check) {
+ // i = min(i, i_max);
+ // }
+ //
+ // const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+ //
+ // auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
+ // const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+ // const int2 v = get_int_from_table_16(aux_q4, values);
+ // const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+ //#ifdef INT8_MMA_AVAILABLE
+ // x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ // x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+ //#else
+ // x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+ // x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+ //#endif // INT8_MMA_AVAILABLE
+ // }
+ //
+ //#pragma unroll
+ // for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ // int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+ //
+ // if (need_check) {
+ // i = min(i, i_max);
+ // }
+ //
+ // const float * dptr = (const float *)(x + i*stride);
+ // const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+ // const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
+ //
+ //#ifdef INT8_MMA_AVAILABLE
+ // x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+ //#else
+ // x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+ //#endif // INT8_MMA_AVAILABLE
+ // }
+ //}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@@ -2713,35 +2770,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
- const int kbx = 0; // threadIdx.x / QI4_XS
- const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+ const int kqsx = threadIdx.x / 4;
#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
- int i = i0 + threadIdx.y;
- if (need_check) {
- i = min(i, i_max);
- }
- const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
- auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
- const int aux_q4 = get_int_b4(bxi->qs, kqsx);
- const int2 v = get_int_from_table_16(aux_q4, values);
- const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
- #ifdef INT8_MMA_AVAILABLE
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
- #else
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
- #endif // INT8_MMA_AVAILABLE
- }
- #pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
- int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+ for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
+ int i = i0 + 4*threadIdx.y + threadIdx.x%4;
if (need_check) {
i = min(i, i_max);
@@ -2749,16 +2782,31 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const float * dptr = (const float *)(x + i*stride);
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
- const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
+ const int ls = (bxi->scales[kqsx] & 254) - 127;
+ auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
+ #pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ const int aux_q4 = get_int_b4(bxi->qs, 4*kqsx+j);
+ const int2 v = get_int_from_table_16(aux_q4, values);
#ifdef INT8_MMA_AVAILABLE
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
#else
- x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
#endif // INT8_MMA_AVAILABLE
+ }
+ #ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ls;
+ #else
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ls;
+ #endif // INT8_MMA_AVAILABLE
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks_r4(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {