mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 18:10:02 +00:00
POC per row scale: iq2_tn on Metal
This commit is contained in:
@@ -3778,21 +3778,21 @@ void kernel_mul_mv_iq2_tn_f32_impl(
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int row_size = nb*sizeof(block_iq2_tn) + 4;
|
||||
|
||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||
const int ib_row = first_row * nb;
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
|
||||
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
const uint offset0 = ((i12/r2)*ne01 + (i13/r3)*ne01*ne02)*row_size;
|
||||
|
||||
device const block_iq2_tn * x = (device const block_iq2_tn *) src0 + ib_row + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
device const char * cx = (device const char *) src0 + first_row*row_size + offset0;
|
||||
device const float * y = (device const float*) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
||||
const int step = sizeof(block_iq2_tn) * nb / 2;
|
||||
float drow[N_DST];
|
||||
|
||||
const int ix = tiisg/8; // 0...3
|
||||
const int it = tiisg%8; // 0...7
|
||||
@@ -3802,6 +3802,8 @@ void kernel_mul_mv_iq2_tn_f32_impl(
|
||||
|
||||
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) drow[row] = *((device const float *)(cx + row*row_size));
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
|
||||
float sumy = 0.f;
|
||||
@@ -3812,7 +3814,7 @@ void kernel_mul_mv_iq2_tn_f32_impl(
|
||||
yl[i+24] = y4[i+96]; sumy += yl[i+24];
|
||||
}
|
||||
|
||||
device const half * dh = &x[ib].d;
|
||||
device const block_iq2_tn * x = (device const block_iq2_tn *)(cx + 4);
|
||||
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
@@ -3829,14 +3831,12 @@ void kernel_mul_mv_iq2_tn_f32_impl(
|
||||
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
||||
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
||||
}
|
||||
float dall = dh[0];
|
||||
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f +
|
||||
(acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f +
|
||||
(acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f +
|
||||
(acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy);
|
||||
sumf[row] += (acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f +
|
||||
(acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f +
|
||||
(acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f +
|
||||
(acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy;
|
||||
|
||||
qs += step;
|
||||
dh += step;
|
||||
qs += row_size/2;
|
||||
}
|
||||
|
||||
y4 += 4 * QK_K;
|
||||
@@ -3845,7 +3845,7 @@ void kernel_mul_mv_iq2_tn_f32_impl(
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = drow[row]*all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6850,16 +6850,14 @@ void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & re
|
||||
|
||||
// Unpack one group of 16 two-bit quants from an iq2_tn block into `reg`.
//
//   xb  - block to read; only xb->qs is accessed
//   il  - selects the group: 32*(il/8) + 16*(il&1) is the byte offset into qs,
//         and (il/2)%4 picks which 2-bit field of each byte is used
//   reg - receives the 16 values, written as reg[i/4][i%4]
//
// `mask` isolates the selected 2-bit field in place and `coef`
// (1, 1/4, 1/16, 1/64) undoes its shift, so a field value v is stored as v - 1.
// No scale is applied here: the earlier per-block `d` multiply has been dropped,
// and DequantizerIQ2TN::convert applies its own scale to the result (`t *= d`).
template <typename type4x4>
void dequantize_iq2_tn(device const block_iq2_tn * xb, short il, thread type4x4 & reg) {
    device const uint8_t * q = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1);

    il = (il/2)%4;

    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = coef * (q[i] & mask) - 1;
    }
}
|
||||
|
||||
@@ -7472,6 +7470,32 @@ struct DequantizerIQ1TN {
|
||||
half d;
|
||||
};
|
||||
|
||||
// Dequantizer for iq2_tn data laid out as one float scale immediately followed
// by block_iq2_tn blocks (presumably one such scale per row -- confirm against
// the host-side layout).
template <typename T4x4>
struct DequantizerIQ2TN {
    using type4x4 = T4x4;
    using Block   = block_iq2_tn;
    constexpr constant static int nl = 16;

    // cx points at the float scale; the blocks start sizeof(float) bytes later.
    DequantizerIQ2TN(device const char * cx, short il = 0)
        : x((device const Block *)(cx + sizeof(float))),
          il(il),
          d(*(device const float *)cx) {}

    // Dequantize the current (block, il) group and apply the scale.
    inline void convert(thread T4x4& t) const {
        dequantize_iq2_tn(x, il, t);
        t *= d;
    }

    // Dequantize the group at flat index `ind`, counted in groups from the
    // first block: ind/nl selects the block, ind%nl the group inside it.
    inline void convert(int64_t ind, thread T4x4& t) {
        device const Block * blk = x + ind/nl;
        const short          sub = ind%nl;
        dequantize_iq2_tn(blk, sub, t);
        t *= d;
    }

    // Advance by two groups; when il wraps past nl, keep its parity and move
    // x forward by (2+nl-1)/nl blocks.
    inline void next() {
        if (il + 2 < nl) {
            il += 2;
        } else {
            il %= 2;
        }
        if (il < 2) {
            x += (2+nl-1)/nl;
        }
    }

    device const Block * x;
    short il;
    float d;
};
|
||||
|
||||
// each block_q contains 16*nl weights
|
||||
template<typename T, typename simdgroup_T8x8, typename Dequantizer>
|
||||
kernel void kernel_mul_mm(device const uchar * src0,
|
||||
@@ -7821,7 +7845,6 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get
|
||||
// kernel_get_rows_iq2_tn is instantiated further down with DequantizerIQ2TN;
// a second instantiation here would register the same host_name twice.
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
@@ -7843,6 +7866,7 @@ template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get
|
||||
template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_bn, 4, dequantize_iq1_bn>;
|
||||
template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_bn, 4, dequantize_iq2_bn>;
|
||||
template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerIQ1TN<float4x4>>;
|
||||
template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerIQ2TN<float4x4>>;
|
||||
|
||||
//
|
||||
// matrix-matrix multiplication
|
||||
@@ -7862,7 +7886,6 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_m
|
||||
// kernel_mul_mm_iq2_tn_f32 is instantiated further down with DequantizerIQ2TN;
// a second instantiation here would register the same host_name twice.
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_1, 2, dequantize_q5_1>>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q8_0, 2, dequantize_q8_0>>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q2_K, QK_NL, dequantize_q2_K>>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q3_K, QK_NL, dequantize_q3_K>>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_K, QK_NL, dequantize_q4_K>>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_K, QK_NL, dequantize_q5_K>>;
|
||||
@@ -7884,6 +7907,7 @@ template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_m
|
||||
template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_bn, 4, dequantize_iq1_bn>>;
|
||||
template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_bn, 4, dequantize_iq2_bn>>;
|
||||
template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerIQ1TN<half4x4>>;
|
||||
template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerIQ2TN<half4x4>>;
|
||||
|
||||
//
|
||||
// indirect matrix-matrix multiplication
|
||||
@@ -7900,7 +7924,6 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel
|
||||
// kernel_mul_mm_id_iq2_tn_f32 is instantiated further down with DequantizerIQ2TN;
// a second instantiation here would register the same host_name twice.
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_1, 2, dequantize_q5_1>>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q8_0, 2, dequantize_q8_0>>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q2_K, QK_NL, dequantize_q2_K>>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q3_K, QK_NL, dequantize_q3_K>>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_K, QK_NL, dequantize_q4_K>>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_K, QK_NL, dequantize_q5_K>>;
|
||||
@@ -7922,6 +7945,7 @@ template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel
|
||||
template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq5_k, QK_NL, dequantize_iq5_k>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq6_k, QK_NL, dequantize_iq6_k>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerIQ1TN<half4x4>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerIQ2TN<half4x4>>;
|
||||
|
||||
//
|
||||
// matrix-vector multiplication
|
||||
|
||||
Reference in New Issue
Block a user