mirror of https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00

Minor CUDA PP speed improvement (#567)

* Slightly better q8_0_q8_1 kernel and iq4_ks tile loading
* Minor

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
@@ -861,61 +861,60 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+    mma_A A[ntx];
+    float dA[ntx][mma_C::ne/2];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
 #pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            const int k0 = k00 + k01;
-
-            A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
-        }
-
-#pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
-
-#pragma unroll
-            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-                const int k0 = k00 + k01;
-
-                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
-            }
-        }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            mma_B B;
-            float dB[mma_C::ne/2];
-
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
-
-                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
-                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
-                } else {
-                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
-                }
-            }
-
-#pragma unroll
-            for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma_K8(A[n][k01/QI8_0], B);
-
-#pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
-                }
-            }
-        }
-    }
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+        mma_B B;
+        float dB[mma_C::ne/2];
+        B.load(y_qs + k01, MMQ_TILE_Y_K);
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = mma_C::get_j(l);
+            if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            } else {
+                dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+                dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+            }
+            mma_C C;
+            C.mma_K8(A[n], B);
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+            }
+        }
+#pragma unroll
+        for (int j0 = ntx*mma_C::J; j0 < mmq_x; j0 += ntx*mma_C::J) {
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+                if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+                } else {
+                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                }
+            }
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n], B);
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+                }
+            }
+        }
+    }
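In short, this hunk swaps the loop order in vec_dot_q8_0_q8_1_mma: the old code preloaded one mma_A fragment per k-step (A[ntx][WARP_SIZE/QI8_0]) together with all dA scales before entering the j0/k01 loops, while the new code makes k01 the outermost loop, loads a single A fragment per row tile (plus its scales) inside it, and reuses them for every column tile, so far fewer fragments stay live in registers at once. Below is a minimal stand-alone sketch of just that restructuring, with mma fragments modeled as plain floats and K_STEPS/NTX/J_TILES as illustrative constants rather than the kernel's real parameters.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int K_STEPS = 4; // stands in for WARP_SIZE/QI8_0 (illustrative)
constexpr int NTX     = 2; // row tiles per warp (illustrative)
constexpr int J_TILES = 4; // column tiles (illustrative)

// Toy analogue of the restructured loop nest: k is outermost, the A values
// are loaded once per k-step and reused for every column tile, so only NTX
// of them (not NTX*K_STEPS) are live at any time.
__global__ void dot_restructured(const float * A_src, const float * B_src, float * sum) {
    for (int k = 0; k < K_STEPS; ++k) {
        float A[NTX];                              // live only for this k-step
        for (int n = 0; n < NTX; ++n) {
            A[n] = A_src[n*K_STEPS + k];           // stands in for A[n].load(...)
        }
        for (int j0 = 0; j0 < J_TILES; ++j0) {
            const float B = B_src[j0*K_STEPS + k]; // stands in for B.load(...)
            for (int n = 0; n < NTX; ++n) {
                sum[j0*NTX + n] += A[n] * B;       // stands in for C.mma_K8(A[n], B)
            }
        }
    }
}

int main() {
    float A[NTX*K_STEPS], B[J_TILES*K_STEPS], S[J_TILES*NTX];
    for (int i = 0; i < NTX*K_STEPS;     ++i) A[i] = 1.0f;
    for (int i = 0; i < J_TILES*K_STEPS; ++i) B[i] = 2.0f;
    float *dA, *dB, *dS;
    cudaMalloc(&dA, sizeof(A));
    cudaMalloc(&dB, sizeof(B));
    cudaMalloc(&dS, sizeof(S));
    cudaMemcpy(dA, A, sizeof(A), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B, sizeof(B), cudaMemcpyHostToDevice);
    cudaMemset(dS, 0, sizeof(S));
    dot_restructured<<<1, 1>>>(dA, dB, dS);
    cudaMemcpy(S, dS, sizeof(S), cudaMemcpyDeviceToHost);
    printf("sum[0] = %.1f (expected %.1f)\n", S[0], 2.0f*K_STEPS);
    return 0;
}

A single thread does all the work here; in the real kernel the i/j indices are spread across a warp's mma fragments, but the register-lifetime argument is the same.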
@@ -2701,6 +2700,64 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     }
 }
 
+//template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
+//    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+//
+//#ifdef INT8_MMA_AVAILABLE
+//    int   * x_qs = (int   *) x_tile;
+//    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+//#else
+//    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+//    int   * x_qs = (int   *) x_tile;
+//    float * x_df = (float *) (x_qs + txs.qs);
+//#endif // INT8_MMA_AVAILABLE
+//
+//    const int kbx  = 0;           // threadIdx.x / QI4_XS
+//    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+//
+//#pragma unroll
+//    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+//        int i = i0 + threadIdx.y;
+//
+//        if (need_check) {
+//            i = min(i, i_max);
+//        }
+//
+//        const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+//
+//        auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
+//        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+//        const int2 v = get_int_from_table_16(aux_q4, values);
+//        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+//#ifdef INT8_MMA_AVAILABLE
+//        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+//        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+//#else
+//        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+//        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+//#endif // INT8_MMA_AVAILABLE
+//    }
+//
+//#pragma unroll
+//    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+//        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+//
+//        if (need_check) {
+//            i = min(i, i_max);
+//        }
+//
+//        const float * dptr = (const float *)(x + i*stride);
+//        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+//        const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
+//
+//#ifdef INT8_MMA_AVAILABLE
+//        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+//#else
+//        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+//#endif // INT8_MMA_AVAILABLE
+//    }
+//}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@@ -2713,35 +2770,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     float * x_df = (float *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
 
-    const int kbx  = 0;           // threadIdx.x / QI4_XS
-    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
-
-        auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
-        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4, values);
-        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
-#else
-        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
-        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
-#endif // INT8_MMA_AVAILABLE
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+    const int kqsx = threadIdx.x / 4;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
+        int i = i0 + 4*threadIdx.y + threadIdx.x%4;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2749,16 +2782,31 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         }
 
         const float * dptr = (const float *)(x + i*stride);
         const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
-        const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
-
-#ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
-#else
-        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
-#endif // INT8_MMA_AVAILABLE
+        const int ls = (bxi->scales[kqsx] & 254) - 127;
+        auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
+
+#pragma unroll
+        for (int j = 0; j < 4; ++j) {
+            const int aux_q4 = get_int_b4(bxi->qs, 4*kqsx+j);
+            const int2 v = get_int_from_table_16(aux_q4, values);
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
+#else
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ls;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ls;
+#endif // INT8_MMA_AVAILABLE
     }
 
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks_r4(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
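Across the two load_tiles_iq4_ks hunks the warp's work distribution changes: previously every lane unpacked a single int of one row (kqsx = threadIdx.x) and the block scales were read again in a separate loop; now a warp covers 4 rows at once, with threadIdx.x%4 selecting the row and kqsx = threadIdx.x/4 selecting one of 8 scale blocks, and each lane unpacks that block's 4 ints right after reading its scale byte once. A small host-side sketch of just this lane-to-(row, block) arithmetic (illustrative, not repository code):

#include <cstdio>

int main() {
    const int warp_size = 32;
    for (int lane = 0; lane < warp_size; ++lane) {
        const int row  = lane % 4; // threadIdx.x%4 in the kernel
        const int kqsx = lane / 4; // threadIdx.x/4 in the kernel
        // Each lane owns qs ints 4*kqsx..4*kqsx+3 (32 packed 4-bit values)
        // of its row, plus the matching scale byte scales[kqsx].
        printf("lane %2d -> row %d, qs ints [%2d..%2d], scales[%d]\n",
               lane, row, 4*kqsx, 4*kqsx + 3, kqsx);
    }
    return 0;
}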
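The fused loop also reuses the selected values table (picked once from the low bit of the scale byte) for all four ints. For reference, a scalar, self-contained stand-in for the nibble lookup the loader relies on; the repository's get_int_from_table_16 achieves the same mapping with byte-permute intrinsics, so treat this as assumed semantics rather than the actual implementation:

#include <cstdint>
#include <cstdio>

struct int2_t { int x, y; };

// Map the four low nibbles of q4 through `values` into r.x and the four
// high nibbles into r.y, one signed byte per nibble (assumed semantics).
static int2_t toy_int_from_table_16(uint32_t q4, const int8_t * values) {
    uint32_t lo = 0, hi = 0;
    for (int b = 0; b < 4; ++b) {
        const uint32_t byte = (q4 >> (8*b)) & 0xff;
        lo |= (uint32_t)(uint8_t)values[byte & 0xf] << (8*b); // low nibble
        hi |= (uint32_t)(uint8_t)values[byte >>  4] << (8*b); // high nibble
    }
    return int2_t{(int)lo, (int)hi};
}

int main() {
    // 16 signed reconstruction values in the style of an iq4 table
    // (made-up numbers, not the real iq4k_values).
    const int8_t values[16] = {-63, -40, -23, -10, 1, 13, 28, 47,
                               -59, -35, -18,  -5, 5, 20, 35, 62};
    const int2_t v = toy_int_from_table_16(0x01234567u, values);
    printf("v.x = %08x, v.y = %08x\n", (unsigned)v.x, (unsigned)v.y);
    return 0;
}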