mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
cuda: MMQ for iq5_k_r4
This commit is contained in:
@@ -142,6 +142,9 @@ void ggml_cuda_op_mul_mat_q(
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
mul_mat_q_case<GGML_TYPE_IQ4_K_R4>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_K_R4:
|
||||
mul_mat_q_case<GGML_TYPE_IQ5_K_R4>(ctx, args, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
@@ -196,6 +199,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
case GGML_TYPE_IQ2_K_R4:
|
||||
case GGML_TYPE_IQ3_K_R4:
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
case GGML_TYPE_IQ5_K_R4:
|
||||
mmq_supported = true;
|
||||
break;
|
||||
default:
|
||||
|
||||
@@ -93,6 +93,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
case GGML_TYPE_IQ5_K:
|
||||
case GGML_TYPE_IQ5_K_R4:
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
case GGML_TYPE_IQ5_KS_R4:
|
||||
case GGML_TYPE_IQ6_K:
|
||||
@@ -210,6 +211,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
||||
case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16;
|
||||
case GGML_TYPE_IQ4_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
|
||||
case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16;
|
||||
case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
|
||||
case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16;
|
||||
case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_IQ3_KT : return MMQ_DP4A_TXS_Q8_0;
|
||||
@@ -264,6 +266,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_IQ4_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_IQ3_KT : return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
@@ -4102,6 +4105,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K_R4);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K_R4);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);
|
||||
|
||||
@@ -32,35 +32,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
|
||||
const float d = __half2float(bxi->d[ir]);
|
||||
|
||||
//void dequantize_row_iq4_k_r4(const block_iq4_k_r4 * x, float * y, int64_t k) {
|
||||
// auto n_per_row = k/4;
|
||||
// float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row};
|
||||
// int nblock = n_per_row/QK_K;
|
||||
// for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||
// for (int k = 0; k < 4; ++k) {
|
||||
// const float d = GGML_FP16_TO_FP32(x[ibl].d[k]);
|
||||
// for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
// int is = 8*ib + k;
|
||||
// float dl1 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
|
||||
// is += 4;
|
||||
// float dl2 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
|
||||
// auto values1 = iq4k_values + (x[ibl].extra[k+0] & (1 << ib) ? 16 : 0);
|
||||
// auto values2 = iq4k_values + (x[ibl].extra[k+4] & (1 << ib) ? 16 : 0);
|
||||
|
||||
// y4[k][8*ib+i+0] = dl1 * values1[x[ibl].qs[16*ib+k+ 0] & 0xf];
|
||||
// y4[k][8*ib+i+2] = dl1 * values1[x[ibl].qs[16*ib+k+ 0] >> 4];
|
||||
// y4[k][8*ib+i+4] = dl2 * values2[x[ibl].qs[16*ib+k+ 4] & 0xf];
|
||||
// y4[k][8*ib+i+6] = dl2 * values2[x[ibl].qs[16*ib+k+ 4] >> 4];
|
||||
// y4[k][8*ib+i+1] = dl1 * values1[x[ibl].qs[16*ib+k+ 8] & 0xf];
|
||||
// y4[k][8*ib+i+3] = dl1 * values1[x[ibl].qs[16*ib+k+ 8] >> 4];
|
||||
// y4[k][8*ib+i+5] = dl2 * values2[x[ibl].qs[16*ib+k+12] & 0xf];
|
||||
// y4[k][8*ib+i+7] = dl2 * values2[x[ibl].qs[16*ib+k+12] >> 4];
|
||||
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmq.cuh"
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_k_r4(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
||||
#else
|
||||
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
|
||||
int * x_qs = (int *) x_tile;
|
||||
float * x_df = (float *) (x_qs + txs.qs);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
const int kqsx = threadIdx.x/4; // 0...7 -> block of 32
|
||||
|
||||
uint32_t aux32[4];
|
||||
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
|
||||
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
}
|
||||
int i4 = i/4;
|
||||
int ir = i%4;
|
||||
|
||||
const block_iq5_k_r4 * bxi = (const block_iq5_k_r4 *)(x + 4*i4*stride) + kbx0;
|
||||
|
||||
const float d = __half2float(bxi->d[ir]);
|
||||
|
||||
int qh = get_int_b4(bxi->qh, 4*kqsx + ir);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
|
||||
auto values_l = iq5nl_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 5);
|
||||
|
||||
const int ql1 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 0);
|
||||
const int ql2 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 8);
|
||||
aux32[0] = ((ql1 >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010);
|
||||
aux32[1] = ((ql1 >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010);
|
||||
aux32[2] = ((ql2 >> 0) & 0x0f0f0f0f) | ((qh >> 0) & 0x10101010);
|
||||
aux32[3] = ((ql2 >> 4) & 0x0f0f0f0f) | ((qh >> 1) & 0x10101010);
|
||||
|
||||
const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]);
|
||||
const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]);
|
||||
const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]);
|
||||
const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]);
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val1;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val2;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val1;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val2;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
qh >>= 2;
|
||||
}
|
||||
|
||||
int is = 8*kqsx + ir;
|
||||
float dl1 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
|
||||
is += 4;
|
||||
float dl2 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1;
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2;
|
||||
#else
|
||||
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1;
|
||||
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_K_R4> {
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k_r4<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
};
|
||||
|
||||
DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4);
|
||||
Reference in New Issue
Block a user