mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
iq4_kt: Metal still not working
This commit is contained in:
@@ -6931,18 +6931,23 @@ void kernel_mul_mv_iq4_kt_f32_impl(
|
|||||||
float4 v[2];
|
float4 v[2];
|
||||||
thread uint32_t * u32 = (thread uint32_t *)v;
|
thread uint32_t * u32 = (thread uint32_t *)v;
|
||||||
|
|
||||||
float drow[2*N_DST];
|
//float drow[2*N_DST];
|
||||||
|
//for (int row = 0; row < N_DST; ++row) {
|
||||||
|
// device const float * dptr = (device const float *)(cx + row*row_size);
|
||||||
|
// drow[2*row+0] = dptr[0] * 31.75f * 1.01f;
|
||||||
|
// drow[2*row+1] = dptr[1];
|
||||||
|
//}
|
||||||
|
float drow[N_DST];
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
device const float * dptr = (device const float *)(cx + row*row_size);
|
device const float * dptr = (device const float *)(cx + row*row_size);
|
||||||
drow[2*row+0] = dptr[0] * 31.75f * 1.01f;
|
drow[row] = dptr[0] * 31.75f * 1.01f;
|
||||||
drow[2*row+1] = dptr[1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
device const block_iq4_kt * x = (device const block_iq4_kt *)(cx + 2*sizeof(float));
|
device const block_iq4_kt * x = (device const block_iq4_kt *)(cx + 2*sizeof(float));
|
||||||
|
|
||||||
for (int ib = ix; ib < nb; ib += 2) {
|
for (int ib = ix; ib < nb; ib += 2) {
|
||||||
|
|
||||||
auto sumy = y4[0] + y4[1] + y4[2] + y4[3];
|
//auto sumy = y4[0] + y4[1] + y4[2] + y4[3];
|
||||||
|
|
||||||
device const uint32_t * shb = x[ib].qs;
|
device const uint32_t * shb = x[ib].qs;
|
||||||
|
|
||||||
@@ -6951,7 +6956,7 @@ void kernel_mul_mv_iq4_kt_f32_impl(
|
|||||||
device const uint8_t * ql = (device const uint8_t *)(shb + 8);
|
device const uint8_t * ql = (device const uint8_t *)(shb + 8);
|
||||||
device const uint8_t * qh = ql + 64;
|
device const uint8_t * qh = ql + 64;
|
||||||
|
|
||||||
const float ls = drow[2*row] * (((shb[it/2] & 0xff) >> 1) - 64);
|
const float ls = drow[row] * (((shb[it/2] & 0xff) >> 1) - 64);
|
||||||
|
|
||||||
const int jj = 8*(it/2) + 4*(it%2);
|
const int jj = 8*(it/2) + 4*(it%2);
|
||||||
ql += jj;
|
ql += jj;
|
||||||
@@ -6969,7 +6974,8 @@ void kernel_mul_mv_iq4_kt_f32_impl(
|
|||||||
}
|
}
|
||||||
sum *= ls;
|
sum *= ls;
|
||||||
|
|
||||||
sumf[row] += sum[0] + sum[1] + sum[2] + sum[3] + drow[2*row+1]*(sumy[0] + sumy[1] + sumy[2] + sumy[3]);
|
//sumf[row] += sum[0] + sum[1] + sum[2] + sum[3] + drow[2*row+1]*(sumy[0] + sumy[1] + sumy[2] + sumy[3]);
|
||||||
|
sumf[row] += sum[0] + sum[1] + sum[2] + sum[3];
|
||||||
|
|
||||||
shb += row_size/4;
|
shb += row_size/4;
|
||||||
|
|
||||||
@@ -8574,13 +8580,13 @@ void dequantize_iq3_kt(device const block_iq3_kt * x, short il, thread type4x4 &
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void dequantize_iq4_kt(device const block_iq4_kt * x, short il, thread float4x4 & reg) {
|
void dequantize_iq4_kt(device const block_iq4_kt * x, short il, float d, thread float4x4 & reg) {
|
||||||
// il is 0...15 for QK_K = 256
|
// il is 0...15 for QK_K = 256
|
||||||
int ib32 = il/2;
|
int ib32 = il/2;
|
||||||
device const uint32_t * shb = x->qs;
|
device const uint32_t * shb = x->qs;
|
||||||
device const uint8_t * ql = (device const uint8_t *)(shb + 8);
|
device const uint8_t * ql = (device const uint8_t *)(shb + 8);
|
||||||
device const uint8_t * qh = ql + 64;
|
device const uint8_t * qh = ql + 64;
|
||||||
float scale = (((shb[ib32] & 0xff) >> 1) - 64);
|
float scale = d * (((shb[ib32] & 0xff) >> 1) - 64);
|
||||||
const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);
|
const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);
|
||||||
|
|
||||||
const int jj = ib32*8 + 4*(il%2);
|
const int jj = ib32*8 + 4*(il%2);
|
||||||
@@ -8592,8 +8598,8 @@ void dequantize_iq4_kt(device const block_iq4_kt * x, short il, thread float4x4
|
|||||||
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
uint32_t idx = ql[i] + ((qh[i] << shift) & 0xf00) + ((sh >> 3*i) & 0x7000) + offset;
|
uint32_t idx = ql[i] + ((qh[i] << shift) & 0xf00) + ((sh >> 3*i) & 0x7000) + offset;
|
||||||
auto v = Trellis::gen4(idx);
|
auto v = (float4)Trellis::gen4(idx);
|
||||||
reg[i] = {(float)v[0]*scale, (float)v[1]*scale, (float)v[2]*scale, (float)v[3]*scale };
|
reg[i] = v * scale;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -8919,8 +8925,9 @@ struct DequantizerRST4 {
|
|||||||
Scale d[2];
|
Scale d[2];
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T4x4, typename Block, int nl, void (*dequantize)(device const Block *, short, thread float4x4&)>
|
template <typename T4x4, int nl>
|
||||||
struct DequantizerKT4 {
|
struct DequantizerKT4 {
|
||||||
|
using Block = block_iq4_kt;
|
||||||
using type4x4 = T4x4;
|
using type4x4 = T4x4;
|
||||||
DequantizerKT4(device const char * cx, short il = 0) : il(il) {
|
DequantizerKT4(device const char * cx, short il = 0) : il(il) {
|
||||||
device const float * dptr = (device const float *)cx;
|
device const float * dptr = (device const float *)cx;
|
||||||
@@ -8930,13 +8937,13 @@ struct DequantizerKT4 {
|
|||||||
}
|
}
|
||||||
inline void convert(thread T4x4& t) const {
|
inline void convert(thread T4x4& t) const {
|
||||||
float4x4 tmp;
|
float4x4 tmp;
|
||||||
dequantize(x, il, tmp);
|
dequantize_iq4_kt(x, il, d[0], tmp);
|
||||||
for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]*d[0] + d[1];
|
for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j];
|
||||||
}
|
}
|
||||||
inline void convert(int64_t ind, thread T4x4& t) {
|
inline void convert(int64_t ind, thread T4x4& t) {
|
||||||
float4x4 tmp;
|
float4x4 tmp;
|
||||||
dequantize(x + ind/nl, ind%nl, tmp);
|
dequantize_iq4_kt(x + ind/nl, ind%nl, d[0], tmp);
|
||||||
for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]*d[0] + d[1];
|
for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j];
|
||||||
}
|
}
|
||||||
inline void next() {
|
inline void next() {
|
||||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||||
@@ -9389,7 +9396,7 @@ template [[host_name("kernel_get_rows_iq4_kss")]] kernel get_rows_q_t kernel_get
|
|||||||
template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
|
template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
|
||||||
template [[host_name("kernel_get_rows_iq2_kt")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>>;
|
template [[host_name("kernel_get_rows_iq2_kt")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>>;
|
||||||
template [[host_name("kernel_get_rows_iq3_kt")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>>;
|
template [[host_name("kernel_get_rows_iq3_kt")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>>;
|
||||||
template [[host_name("kernel_get_rows_iq4_kt")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerKT4<float4x4, block_iq4_kt, 16, dequantize_iq4_kt>>;
|
template [[host_name("kernel_get_rows_iq4_kt")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerKT4<float4x4, 16>>;
|
||||||
|
|
||||||
//
|
//
|
||||||
// matrix-matrix multiplication
|
// matrix-matrix multiplication
|
||||||
@@ -9436,8 +9443,7 @@ template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_m
|
|||||||
template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>, float>;
|
template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>, float>;
|
||||||
template [[host_name("kernel_mul_mm_iq2_kt_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>, float>;
|
template [[host_name("kernel_mul_mm_iq2_kt_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>, float>;
|
||||||
template [[host_name("kernel_mul_mm_iq3_kt_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>, float>;
|
template [[host_name("kernel_mul_mm_iq3_kt_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>, float>;
|
||||||
//template [[host_name("kernel_mul_mm_iq4_kt_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerKT4<half4x4, block_iq4_kt, 16, dequantize_iq4_kt>, float>;
|
template [[host_name("kernel_mul_mm_iq4_kt_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerKT4<half4x4, 16>, float>;
|
||||||
template [[host_name("kernel_mul_mm_iq4_kt_f32")]] kernel mat_mm_t kernel_mul_mm<float, simdgroup_float8x8, DequantizerKT4<float4x4, block_iq4_kt, 16, dequantize_iq4_kt>, float>;
|
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>, half>;
|
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>, half>;
|
||||||
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_f16>, half>;
|
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_f16>, half>;
|
||||||
@@ -9475,8 +9481,7 @@ template [[host_name("kernel_mul_mm_iq4_kss_f16")]] kernel mat_mm_t kernel_mul_m
|
|||||||
template [[host_name("kernel_mul_mm_iq2_ks_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>, half>;
|
template [[host_name("kernel_mul_mm_iq2_ks_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>, half>;
|
||||||
template [[host_name("kernel_mul_mm_iq2_kt_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>, half>;
|
template [[host_name("kernel_mul_mm_iq2_kt_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>, half>;
|
||||||
template [[host_name("kernel_mul_mm_iq3_kt_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>, half>;
|
template [[host_name("kernel_mul_mm_iq3_kt_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>, half>;
|
||||||
//template [[host_name("kernel_mul_mm_iq4_kt_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerKT4<half4x4, block_iq4_kt, 16, dequantize_iq4_kt>, half>;
|
template [[host_name("kernel_mul_mm_iq4_kt_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerKT4<half4x4, 16>, half>;
|
||||||
template [[host_name("kernel_mul_mm_iq4_kt_f16")]] kernel mat_mm_t kernel_mul_mm<float, simdgroup_float8x8, DequantizerKT4<float4x4, block_iq4_kt, 16, dequantize_iq4_kt>, half>;
|
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
@@ -9521,7 +9526,7 @@ template [[host_name("kernel_mul_mm_id_iq4_kss_f32")]] kernel mat_mm_id_t kernel
|
|||||||
template [[host_name("kernel_mul_mm_id_iq2_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
|
template [[host_name("kernel_mul_mm_id_iq2_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq2_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>>;
|
template [[host_name("kernel_mul_mm_id_iq2_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq3_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>>;
|
template [[host_name("kernel_mul_mm_id_iq3_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq4_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerKT4<half4x4, block_iq4_kt, 16, dequantize_iq4_kt>>;
|
template [[host_name("kernel_mul_mm_id_iq4_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerKT4<half4x4, 16>>;
|
||||||
|
|
||||||
//
|
//
|
||||||
// matrix-vector multiplication
|
// matrix-vector multiplication
|
||||||
|
|||||||
Reference in New Issue
Block a user