mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
iq5_ks: Metal dequantize
This commit is contained in:
@@ -107,6 +107,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_KS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS,
|
||||
@@ -150,6 +151,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KSS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_K_F32,
|
||||
@@ -186,6 +188,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32,
|
||||
@@ -219,6 +222,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32,
|
||||
@@ -252,6 +256,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16,
|
||||
@@ -285,6 +290,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32,
|
||||
@@ -734,6 +740,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS, get_rows_iq4_ks, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS, get_rows_iq4_kss, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_KS, get_rows_iq5_ks, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K, get_rows_iq2_k, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS, get_rows_iq2_ks, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K, get_rows_iq3_k, true);
|
||||
@@ -776,6 +783,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32, mul_mv_iq4_ks_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KSS_F32, mul_mv_iq4_kss_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_KS_F32, mul_mv_iq5_ks_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32, mul_mv_iq2_k_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KS_F32, mul_mv_iq2_ks_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_K_F32, mul_mv_iq3_k_f32, ctx->support_simdgroup_reduction);
|
||||
@@ -812,6 +820,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32, mul_mv_id_iq4_ks_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32, mul_mv_id_iq4_kss_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_KS_F32, mul_mv_id_iq5_ks_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32, mul_mv_id_iq2_k_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32, mul_mv_id_iq2_ks_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32, mul_mv_id_iq3_k_f32, ctx->support_simdgroup_reduction);
|
||||
@@ -845,6 +854,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32, mul_mm_iq4_ks_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32, mul_mm_iq4_kss_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F32, mul_mm_iq5_ks_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32, mul_mm_iq2_k_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32, mul_mm_iq2_ks_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32, mul_mm_iq3_k_f32, ctx->support_simdgroup_mm);
|
||||
@@ -878,6 +888,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, mul_mm_iq4_xs_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, mul_mm_iq4_ks_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16, mul_mm_iq4_kss_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F16, mul_mm_iq5_ks_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16, mul_mm_iq2_k_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16, mul_mm_iq2_ks_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16, mul_mm_iq3_k_f16, ctx->support_simdgroup_mm);
|
||||
@@ -911,6 +922,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32, mul_mm_id_iq4_ks_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32, mul_mm_id_iq4_kss_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_KS_F32, mul_mm_id_iq5_ks_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32, mul_mm_id_iq2_k_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32, mul_mm_id_iq2_ks_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32, mul_mm_id_iq3_k_f32, ctx->support_simdgroup_mm);
|
||||
@@ -2123,6 +2135,7 @@ static void ggml_metal_encode_node(
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break;
|
||||
case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break;
|
||||
@@ -2161,6 +2174,7 @@ static void ggml_metal_encode_node(
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16].pipeline; break;
|
||||
case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16 ].pipeline; break;
|
||||
@@ -2384,6 +2398,12 @@ static void ggml_metal_encode_node(
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KSS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_KS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
{
|
||||
nth0 = 4;
|
||||
@@ -2471,8 +2491,9 @@ static void ggml_metal_encode_node(
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
|
||||
src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
|
||||
src0t == GGML_TYPE_IQ4_KSS) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float);
|
||||
src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float)
|
||||
: src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ5_KS ? 64*sizeof(float) : 32*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
@@ -2568,6 +2589,7 @@ static void ggml_metal_encode_node(
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break;
|
||||
case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_KS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break;
|
||||
@@ -2775,6 +2797,12 @@ static void ggml_metal_encode_node(
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_KS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
{
|
||||
nth0 = 4;
|
||||
@@ -2873,8 +2901,9 @@ static void ggml_metal_encode_node(
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
|
||||
src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
|
||||
src0t == GGML_TYPE_IQ4_KSS) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float);
|
||||
src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float)
|
||||
: src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ5_KS ? 64*sizeof(float) : 32*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
@@ -2926,6 +2955,7 @@ static void ggml_metal_encode_node(
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break;
|
||||
case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_KS ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break;
|
||||
|
||||
@@ -6276,6 +6276,113 @@ void kernel_mul_mv_iq4_ks_f32_impl(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO
|
||||
void kernel_mul_mv_iq5_ks_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
int64_t ne00,
|
||||
int64_t ne01,
|
||||
int64_t ne02,
|
||||
int64_t ne10,
|
||||
int64_t ne12,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
uint r2,
|
||||
uint r3,
|
||||
threadgroup int8_t * shared_values_i8,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
|
||||
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
||||
const int nb = ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
const int first_row = (r0 * 2 + sgitg) * 2;
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
|
||||
const uint row_size = 4 + nb*sizeof(block_iq4_ks);
|
||||
const uint offset0 = (i12/r2)*ne01 + (i13/r3)*(ne01*ne02);
|
||||
device const char * cx = (device const char *)src0 + (first_row + offset0)*row_size;
|
||||
device const float * y = (device const float *)src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
const int ix = tiisg/16; // 0 or 1
|
||||
const int it = tiisg%16; // 0...15
|
||||
const int ib = it/2;
|
||||
const int il = it%2;
|
||||
|
||||
shared_values[tiisg] = kvalues_iq4k_f[tiisg];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float2 sumf = 0.f;
|
||||
float d[2];
|
||||
|
||||
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
||||
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
||||
|
||||
float4 qf1, qf2;
|
||||
|
||||
device const float * dptr = (device const float *)cx;
|
||||
d[0] = *dptr;
|
||||
device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1) + ix;
|
||||
dptr += row_size/4;
|
||||
d[1] = *dptr;
|
||||
|
||||
for (int ibl = ix; ibl < nb; ibl += 2) {
|
||||
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
||||
|
||||
device const uint8_t * scales = x->scales;
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
|
||||
threadgroup const float * block_values = shared_values + ((scales[ib] & 1) << 4);
|
||||
const float ls = ((scales[ib] & 254) - 127);
|
||||
|
||||
device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 4*ib + 2*il;
|
||||
|
||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||
|
||||
aux32[0] = q4[0] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
|
||||
qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]};
|
||||
acc1 += yl[0] * qf1;
|
||||
acc2 += yl[1] * qf2;
|
||||
|
||||
aux32[0] = q4[1] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
|
||||
qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]};
|
||||
acc1 += yl[2] * qf1;
|
||||
acc2 += yl[3] * qf2;
|
||||
|
||||
acc1 += acc2;
|
||||
|
||||
sumf[row] += d[row] * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
||||
|
||||
scales += row_size;
|
||||
|
||||
}
|
||||
|
||||
yb += 2 * QK_K;
|
||||
x += 2;
|
||||
}
|
||||
|
||||
sumf = simd_sum(sumf);
|
||||
if (tiisg < 2) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
|
||||
}
|
||||
}
|
||||
|
||||
void kernel_mul_mv_iq4_kss_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
@@ -7315,6 +7422,35 @@ kernel void kernel_mul_mv_iq4_ks_f32(
|
||||
kernel_mul_mv_iq4_ks_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq5_ks_f32")]]
|
||||
kernel void kernel_mul_mv_iq5_ks_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq5_ks_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq4_kss_f32")]]
|
||||
kernel void kernel_mul_mv_iq4_kss_f32(
|
||||
device const void * src0,
|
||||
@@ -7930,6 +8066,25 @@ void dequantize_iq4_ks(device const block_iq4_ks * xb, short il, thread type4x4
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq5_ks(device const block_iq5_ks * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const int ib32 = il/2;
|
||||
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 8*(ib32/2) + 4*(il%2);
|
||||
device const uint32_t * qh = (device const uint32_t *)xb->qh + 4*(il%2);
|
||||
const float ls = (xb->scales[ib32] & 254) - 127;
|
||||
constant float * values = kvalues_iq5k_f + ((xb->scales[ib32] & 1) << 5);
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
aux32 = ((q4[i] >> 4*(ib32%2)) & 0x0f0f0f0f) | (((qh[i] >> ib32) & 0x01010101) << 4);
|
||||
reg[i][0] = ls * values[q8[0]];
|
||||
reg[i][1] = ls * values[q8[1]];
|
||||
reg[i][2] = ls * values[q8[2]];
|
||||
reg[i][3] = ls * values[q8[3]];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq4_kss(device const block_iq4_kss * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
@@ -8687,6 +8842,7 @@ template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get
|
||||
template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRSBN<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>;
|
||||
template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRSBN<float4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>;
|
||||
template [[host_name("kernel_get_rows_iq4_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>;
|
||||
template [[host_name("kernel_get_rows_iq5_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>>;
|
||||
template [[host_name("kernel_get_rows_iq4_kss")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
|
||||
template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
|
||||
|
||||
@@ -8730,6 +8886,7 @@ template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_m
|
||||
template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq5_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>, float>;
|
||||
|
||||
@@ -8764,6 +8921,7 @@ template [[host_name("kernel_mul_mm_iq6_k_f16")]] kernel mat_mm_t kernel_mul_m
|
||||
template [[host_name("kernel_mul_mm_iq1_bn_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq2_bn_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq4_ks_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq5_ks_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq4_kss_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq2_ks_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>, half>;
|
||||
|
||||
@@ -8805,6 +8963,7 @@ template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel
|
||||
template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRSBN<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRSBN<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq5_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_kss_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
|
||||
|
||||
@@ -9021,6 +9180,7 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t
|
||||
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_ks_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq5_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq5_ks_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_kss_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_kss_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_k_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_ks_f32_impl>>;
|
||||
|
||||
Reference in New Issue
Block a user