diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 6c2e1bcd..a05a890e 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6931,18 +6931,23 @@ void kernel_mul_mv_iq4_kt_f32_impl( float4 v[2]; thread uint32_t * u32 = (thread uint32_t *)v; - float drow[2*N_DST]; + //float drow[2*N_DST]; + //for (int row = 0; row < N_DST; ++row) { + // device const float * dptr = (device const float *)(cx + row*row_size); + // drow[2*row+0] = dptr[0] * 31.75f * 1.01f; + // drow[2*row+1] = dptr[1]; + //} + float drow[N_DST]; for (int row = 0; row < N_DST; ++row) { device const float * dptr = (device const float *)(cx + row*row_size); - drow[2*row+0] = dptr[0] * 31.75f * 1.01f; - drow[2*row+1] = dptr[1]; + drow[row] = dptr[0] * 31.75f * 1.01f; } device const block_iq4_kt * x = (device const block_iq4_kt *)(cx + 2*sizeof(float)); for (int ib = ix; ib < nb; ib += 2) { - auto sumy = y4[0] + y4[1] + y4[2] + y4[3]; + //auto sumy = y4[0] + y4[1] + y4[2] + y4[3]; device const uint32_t * shb = x[ib].qs; @@ -6951,7 +6956,7 @@ void kernel_mul_mv_iq4_kt_f32_impl( device const uint8_t * ql = (device const uint8_t *)(shb + 8); device const uint8_t * qh = ql + 64; - const float ls = drow[2*row] * (((shb[it/2] & 0xff) >> 1) - 64); + const float ls = drow[row] * (((shb[it/2] & 0xff) >> 1) - 64); const int jj = 8*(it/2) + 4*(it%2); ql += jj; @@ -6969,7 +6974,8 @@ void kernel_mul_mv_iq4_kt_f32_impl( } sum *= ls; - sumf[row] += sum[0] + sum[1] + sum[2] + sum[3] + drow[2*row+1]*(sumy[0] + sumy[1] + sumy[2] + sumy[3]); + //sumf[row] += sum[0] + sum[1] + sum[2] + sum[3] + drow[2*row+1]*(sumy[0] + sumy[1] + sumy[2] + sumy[3]); + sumf[row] += sum[0] + sum[1] + sum[2] + sum[3]; shb += row_size/4; @@ -8574,13 +8580,13 @@ void dequantize_iq3_kt(device const block_iq3_kt * x, short il, thread type4x4 & } } -void dequantize_iq4_kt(device const block_iq4_kt * x, short il, thread float4x4 & reg) { +void dequantize_iq4_kt(device const block_iq4_kt * x, short il, float d, thread float4x4 & reg) { // il is 0...15 for QK_K = 256 int ib32 = il/2; device const uint32_t * shb = x->qs; device const uint8_t * ql = (device const uint8_t *)(shb + 8); device const uint8_t * qh = ql + 64; - float scale = (((shb[ib32] & 0xff) >> 1) - 64); + float scale = d * (((shb[ib32] & 0xff) >> 1) - 64); const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15); const int jj = ib32*8 + 4*(il%2); @@ -8592,8 +8598,8 @@ void dequantize_iq4_kt(device const block_iq4_kt * x, short il, thread float4x4 for (int i = 0; i < 4; ++i) { uint32_t idx = ql[i] + ((qh[i] << shift) & 0xf00) + ((sh >> 3*i) & 0x7000) + offset; - auto v = Trellis::gen4(idx); - reg[i] = {(float)v[0]*scale, (float)v[1]*scale, (float)v[2]*scale, (float)v[3]*scale }; + auto v = (float4)Trellis::gen4(idx); + reg[i] = v * scale; } } @@ -8919,8 +8925,9 @@ struct DequantizerRST4 { Scale d[2]; }; -template +template struct DequantizerKT4 { + using Block = block_iq4_kt; using type4x4 = T4x4; DequantizerKT4(device const char * cx, short il = 0) : il(il) { device const float * dptr = (device const float *)cx; @@ -8930,13 +8937,13 @@ struct DequantizerKT4 { } inline void convert(thread T4x4& t) const { float4x4 tmp; - dequantize(x, il, tmp); - for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]*d[0] + d[1]; + dequantize_iq4_kt(x, il, d[0], tmp); + for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]; } inline void convert(int64_t ind, thread T4x4& t) { float4x4 tmp; - dequantize(x + ind/nl, ind%nl, tmp); - for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]*d[0] + d[1]; + dequantize_iq4_kt(x + ind/nl, ind%nl, d[0], tmp); + for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]; } inline void next() { il = (il + 2 < nl) ? il + 2 : il % 2; @@ -9389,7 +9396,7 @@ template [[host_name("kernel_get_rows_iq4_kss")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get_rows_q2>; template [[host_name("kernel_get_rows_iq2_kt")]] kernel get_rows_q_t kernel_get_rows_q2>; template [[host_name("kernel_get_rows_iq3_kt")]] kernel get_rows_q_t kernel_get_rows_q2>; -template [[host_name("kernel_get_rows_iq4_kt")]] kernel get_rows_q_t kernel_get_rows_q2>; +template [[host_name("kernel_get_rows_iq4_kt")]] kernel get_rows_q_t kernel_get_rows_q2>; // // matrix-matrix multiplication @@ -9436,8 +9443,7 @@ template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm, float>; template [[host_name("kernel_mul_mm_iq2_kt_f32")]] kernel mat_mm_t kernel_mul_mm, float>; template [[host_name("kernel_mul_mm_iq3_kt_f32")]] kernel mat_mm_t kernel_mul_mm, float>; -//template [[host_name("kernel_mul_mm_iq4_kt_f32")]] kernel mat_mm_t kernel_mul_mm, float>; -template [[host_name("kernel_mul_mm_iq4_kt_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_kt_f32")]] kernel mat_mm_t kernel_mul_mm, float>; template [[host_name("kernel_mul_mm_f32_f16")]] kernel mat_mm_t kernel_mul_mm, half>; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mat_mm_t kernel_mul_mm, half>; @@ -9475,8 +9481,7 @@ template [[host_name("kernel_mul_mm_iq4_kss_f16")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq2_ks_f16")]] kernel mat_mm_t kernel_mul_mm, half>; template [[host_name("kernel_mul_mm_iq2_kt_f16")]] kernel mat_mm_t kernel_mul_mm, half>; template [[host_name("kernel_mul_mm_iq3_kt_f16")]] kernel mat_mm_t kernel_mul_mm, half>; -//template [[host_name("kernel_mul_mm_iq4_kt_f16")]] kernel mat_mm_t kernel_mul_mm, half>; -template [[host_name("kernel_mul_mm_iq4_kt_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_kt_f16")]] kernel mat_mm_t kernel_mul_mm, half>; // @@ -9521,7 +9526,7 @@ template [[host_name("kernel_mul_mm_id_iq4_kss_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq2_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_iq2_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_iq3_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; -template [[host_name("kernel_mul_mm_id_iq4_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; +template [[host_name("kernel_mul_mm_id_iq4_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; // // matrix-vector multiplication