cuda: faster MMQ for iq2_ks, iq2_k, iq2_k_r4

This commit is contained in:
Iwan Kawrakow
2025-07-08 16:38:43 +03:00
parent 4c0b660266
commit 75cc3d08e8
2 changed files with 61 additions and 57 deletions

View File

@@ -2492,12 +2492,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
const static int minus[2] = {0x1f1f1f1f, 0x1a1a1a1a};
const int kqsx = threadIdx.x%16;
auto values = iq2nl_values;
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
const int * i32 = (const int *)aux32;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) {
int i = i0 + 2*threadIdx.y + threadIdx.x/16;
@@ -2511,26 +2511,26 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
uint16_t extra = bxi->extra >> 4*(kqsx/8);
int q2 = get_int_b2(bxi->qs, kqsx);
aux32[0] = ((q2 >> 0) & 0x03030303) | (((extra << 2) & 4) * 0x01010101);
aux32[1] = ((q2 >> 2) & 0x03030303) | (((extra << 1) & 4) * 0x01010101);
aux32[2] = ((q2 >> 4) & 0x03030303) | (((extra >> 0) & 4) * 0x01010101);
aux32[3] = ((q2 >> 6) & 0x03030303) | (((extra >> 1) & 4) * 0x01010101);
const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]);
const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]);
const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]);
const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]);
aux32[0] = ((q2 >> 0) & 0x03030303);
aux32[1] = ((q2 >> 2) & 0x03030303);
aux32[2] = ((q2 >> 4) & 0x03030303);
aux32[3] = ((q2 >> 6) & 0x03030303);
// per byte: index 0..3 -> (<< 4) 0, 16, 32, 48 -> (+2 where index == 1) 0, 18, 32, 48 -> (- 0x1f) -31, -13, 1, 17 or (- 0x1a) -26, -8, 6, 22
#pragma unroll
for (int j = 0; j < 4; ++j) {
auto mask = __vcmpeq4(aux32[j], 0x01010101);
aux32[j] = __vadd4(aux32[j] << 4, mask & 0x02020202);
}
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = *(const int *)&val0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = *(const int *)&val1;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = *(const int *)&val2;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = *(const int *)&val3;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = __vsubss4(i32[0], minus[(extra >> 0) & 1]);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = __vsubss4(i32[1], minus[(extra >> 1) & 1]);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = __vsubss4(i32[2], minus[(extra >> 2) & 1]);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = __vsubss4(i32[3], minus[(extra >> 3) & 1]);
#else
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = *(const int *)&val0;
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = *(const int *)&val1;
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = *(const int *)&val2;
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = *(const int *)&val3;
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = __vsubss4(i32[0], minus[(extra >> 0) & 1]);
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = __vsubss4(i32[1], minus[(extra >> 1) & 1]);
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = __vsubss4(i32[2], minus[(extra >> 2) & 1]);
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = __vsubss4(i32[3], minus[(extra >> 3) & 1]);
#endif // INT8_MMA_AVAILABLE
}
@@ -2570,13 +2570,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
const static int minus[2] = {0x1f1f1f1f, 0x1a1a1a1a};
constexpr int qstep = 8;
const int kqsx = threadIdx.x % qstep;
auto values = iq2nl_values;
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
const int * i32 = (const int *)aux32;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;
@@ -2595,28 +2595,30 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
for (int l = 0; l < qstep/4; ++l) {
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
aux32[0] = ((ql >> 0) & 0x03030303) | (((extra << 2) & 4) * 0x01010101);
aux32[1] = ((ql >> 2) & 0x03030303) | (((extra << 0) & 4) * 0x01010101);
aux32[2] = ((ql >> 4) & 0x03030303) | (((extra >> 2) & 4) * 0x01010101);
aux32[3] = ((ql >> 6) & 0x03030303) | (((extra >> 4) & 4) * 0x01010101);
extra >>= 8;
const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]);
const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]);
const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]);
const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]);
aux32[0] = ((ql >> 0) & 0x03030303);
aux32[1] = ((ql >> 2) & 0x03030303);
aux32[2] = ((ql >> 4) & 0x03030303);
aux32[3] = ((ql >> 6) & 0x03030303);
// per byte: index 0..3 -> (<< 4) 0, 16, 32, 48 -> (+2 where index == 1) 0, 18, 32, 48 -> (- 0x1f) -31, -13, 1, 17 or (- 0x1a) -26, -8, 6, 22
#pragma unroll
for (int j = 0; j < 4; ++j) {
auto mask = __vcmpeq4(aux32[j], 0x01010101);
aux32[j] = __vadd4(aux32[j] << 4, mask & 0x02020202);
}
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = *(const int *)&val0;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = *(const int *)&val1;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = *(const int *)&val2;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = *(const int *)&val3;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = __vsubss4(i32[0], minus[(extra >> 0) & 1]);
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = __vsubss4(i32[1], minus[(extra >> 2) & 1]);
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = __vsubss4(i32[2], minus[(extra >> 4) & 1]);
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = __vsubss4(i32[3], minus[(extra >> 6) & 1]);
#else
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = *(const int *)&val0;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = *(const int *)&val1;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = *(const int *)&val2;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = *(const int *)&val3;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = __vsubss4(i32[0], minus[(extra >> 0) & 1]);
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = __vsubss4(i32[1], minus[(extra >> 2) & 1]);
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = __vsubss4(i32[2], minus[(extra >> 4) & 1]);
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = __vsubss4(i32[3], minus[(extra >> 6) & 1]);
#endif // INT8_MMA_AVAILABLE
extra >>= 8;
}
#ifdef INT8_MMA_AVAILABLE

View File

@@ -14,10 +14,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
const static int minus[2] = {0x1f1f1f1f, 0x1a1a1a1a};
const int kqsx = threadIdx.x/4; // 0...7 -> block of 32
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
const int * i32 = (const int *)aux32;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -35,29 +37,29 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
#pragma unroll
for (int l = 0; l < 2; ++l) {
auto values_l = iq2nl_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 2);
auto sub = minus[(bxi->extra[ir+4*l] >> kqsx) & 1];
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
aux32[0] = (ql >> 0) & 0x03030303;
aux32[1] = (ql >> 2) & 0x03030303;
aux32[2] = (ql >> 4) & 0x03030303;
aux32[3] = (ql >> 6) & 0x03030303;
const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]);
const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]);
const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]);
const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]);
#pragma unroll
for (int j = 0; j < 4; ++j) {
auto mask = __vcmpeq4(aux32[j], 0x01010101);
aux32[j] = __vadd4(aux32[j] << 4, mask & 0x02020202);
}
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val1;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val2;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = __vsubss4(i32[0], sub);
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = __vsubss4(i32[1], sub);
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = __vsubss4(i32[2], sub);
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = __vsubss4(i32[3], sub);
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val1;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val2;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = __vsubss4(i32[0], sub);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = __vsubss4(i32[1], sub);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = __vsubss4(i32[2], sub);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = __vsubss4(i32[3], sub);
#endif // INT8_MMA_AVAILABLE
}