mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-20 22:49:31 +00:00
iq2_kl: Metal dequantize
This commit is contained in:
@@ -112,6 +112,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KL,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K,
|
||||
@@ -159,6 +160,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_K_F32,
|
||||
@@ -200,6 +202,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32,
|
||||
@@ -238,6 +241,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32,
|
||||
@@ -276,6 +280,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KL_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16,
|
||||
@@ -314,6 +319,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32,
|
||||
@@ -768,6 +774,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_KS, get_rows_iq5_ks, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K, get_rows_iq2_k, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS, get_rows_iq2_ks, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KL, get_rows_iq2_kl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K, get_rows_iq3_k, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K, get_rows_iq4_k, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K, get_rows_iq5_k, true);
|
||||
@@ -815,6 +822,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_KS_F32, mul_mv_iq5_ks_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32, mul_mv_iq2_k_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KS_F32, mul_mv_iq2_ks_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KL_F32, mul_mv_iq2_kl_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_K_F32, mul_mv_iq3_k_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32, mul_mv_iq4_k_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_K_F32, mul_mv_iq5_k_f32, ctx->support_simdgroup_reduction);
|
||||
@@ -856,6 +864,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_KS_F32, mul_mv_id_iq5_ks_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32, mul_mv_id_iq2_k_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32, mul_mv_id_iq2_ks_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KL_F32, mul_mv_id_iq2_kl_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32, mul_mv_id_iq3_k_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32, mul_mv_id_iq4_k_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32, mul_mv_id_iq5_k_f32, ctx->support_simdgroup_reduction);
|
||||
@@ -894,6 +903,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F32, mul_mm_iq5_ks_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32, mul_mm_iq2_k_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32, mul_mm_iq2_ks_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KL_F32, mul_mm_iq2_kl_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32, mul_mm_iq3_k_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, mul_mm_iq4_k_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32, mul_mm_iq5_k_f32, ctx->support_simdgroup_mm);
|
||||
@@ -932,6 +942,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F16, mul_mm_iq5_ks_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16, mul_mm_iq2_k_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16, mul_mm_iq2_ks_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KL_F16, mul_mm_iq2_kl_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16, mul_mm_iq3_k_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16, mul_mm_iq4_k_f16, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16, mul_mm_iq5_k_f16, ctx->support_simdgroup_mm);
|
||||
@@ -970,6 +981,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_KS_F32, mul_mm_id_iq5_ks_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32, mul_mm_id_iq2_k_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32, mul_mm_id_iq2_ks_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KL_F32, mul_mm_id_iq2_kl_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32, mul_mm_id_iq3_k_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32, mul_mm_id_iq4_k_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32, mul_mm_id_iq5_k_f32, ctx->support_simdgroup_mm);
|
||||
@@ -2187,6 +2199,7 @@ static void ggml_metal_encode_node(
|
||||
case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KL_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32 ].pipeline; break;
|
||||
@@ -2230,6 +2243,7 @@ static void ggml_metal_encode_node(
|
||||
case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KL_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16 ].pipeline; break;
|
||||
@@ -2478,6 +2492,12 @@ static void ggml_metal_encode_node(
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_KL:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KL_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ3_K:
|
||||
{
|
||||
nth0 = 4;
|
||||
@@ -2555,7 +2575,8 @@ static void ggml_metal_encode_node(
|
||||
src0t == GGML_TYPE_IQ2_KT|| src0t == GGML_TYPE_IQ3_KT) { //|| src0t == GGML_TYPE_IQ4_KT) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ3_KS) {
|
||||
else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ3_KS ||
|
||||
src0t == GGML_TYPE_IQ2_KL) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float)
|
||||
: src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ3_KS ? 32*sizeof(float) : 16*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
@@ -2675,6 +2696,7 @@ static void ggml_metal_encode_node(
|
||||
case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_KS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KL_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break;
|
||||
@@ -2907,6 +2929,12 @@ static void ggml_metal_encode_node(
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_KL:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KL_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ3_K:
|
||||
{
|
||||
nth0 = 4;
|
||||
@@ -2995,7 +3023,8 @@ static void ggml_metal_encode_node(
|
||||
src0t == GGML_TYPE_IQ2_KT|| src0t == GGML_TYPE_IQ3_KT) { //|| src0t == GGML_TYPE_IQ4_KT) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ3_KS) {
|
||||
else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ3_KS ||
|
||||
src0t == GGML_TYPE_IQ2_KL) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float)
|
||||
: src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ3_KS ? 32*sizeof(float) : 16*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
@@ -3071,6 +3100,7 @@ static void ggml_metal_encode_node(
|
||||
case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_KS ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_KL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KL ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break;
|
||||
case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break;
|
||||
|
||||
@@ -7231,6 +7231,156 @@ kernel void kernel_mul_mv_iq2_ks_f32(
|
||||
kernel_mul_mv_iq2_ks_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
void kernel_mul_mv_iq2_kl_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
int64_t ne00,
|
||||
int64_t ne01,
|
||||
int64_t ne02,
|
||||
int64_t ne10,
|
||||
int64_t ne12,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
uint r2,
|
||||
uint r3,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
|
||||
const int nb = ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||
const uint row_size = 2 + nb*sizeof(block_iq2_kl);
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
|
||||
const uint offset0 = (i12/r2)*(ne01) + (i13/r3)*(ne01*ne02);
|
||||
|
||||
device const char * cx = (device const char *) src0 + (first_row + offset0)*row_size;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f};
|
||||
|
||||
const int ix = tiisg/8; // 0...3
|
||||
const int it = tiisg%8; // 0...7
|
||||
const int iq = it/4; // 0 or 1
|
||||
const int ir = it%4; // 0...3
|
||||
|
||||
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
||||
|
||||
threadgroup float * all_values = (threadgroup float *)shared_values + 32*sgitg;
|
||||
{
|
||||
//int row = tiisg%N_DST;
|
||||
//device const half * dptr = (device const half *)(cx + row*row_size);
|
||||
//const float d = *dptr;
|
||||
//all_values[8*row + tiisg/N_DST] = d*iq2nl_values[tiisg/N_DST];
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
int row = tiisg/8;
|
||||
int pos = tiisg%8;
|
||||
device const half * dptr = (device const half *)(cx + row*row_size);
|
||||
const float d = *dptr;
|
||||
all_values[8*row + pos] = d*kvalues_iq2k_f[pos];
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
cx += sizeof(half);
|
||||
|
||||
uint32_t q32[2];
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+ 0] = y4[i+ 0];
|
||||
yl[i+ 8] = y4[i+32];
|
||||
yl[i+16] = y4[i+64];
|
||||
yl[i+24] = y4[i+96];
|
||||
}
|
||||
|
||||
device const block_iq2_ks * x = (device const block_iq2_ks *)cx + ib;
|
||||
device const uint16_t * q16 = (device const uint16_t *)x->qs + 16*iq + 4*ir;
|
||||
device const uint16_t * sc = (device const uint16_t *)x->scales;
|
||||
device const uint16_t * ex = (device const uint16_t *)&x->extra;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
threadgroup const float * row_values = all_values + 8*row;
|
||||
|
||||
uint32_t sc32 = (sc[iq] | (sc[iq] << 12)) & 0x0f0f0f0f;
|
||||
thread const int8_t * s8 = (thread const int8_t *)&sc32;
|
||||
|
||||
q32[0] = q16[0] | (q16[1] << 16);
|
||||
q32[1] = q16[2] | (q16[3] << 16);
|
||||
|
||||
uint8_t extra = ex[0] << 4*(1-iq);
|
||||
|
||||
float4 acc = {0.f};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
threadgroup const float * values = row_values + ((extra >> (2 + l)) & 4);
|
||||
aux32[0] = (q32[0] >> 2*l) & 0x03030303;
|
||||
aux32[1] = (q32[1] >> 2*l) & 0x03030303;
|
||||
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
|
||||
}
|
||||
extra = ex[0] >> (8 + 4*iq);
|
||||
sumf[row] += acc[0] * (s8[0] - (extra & 1 ? 0 : 16)) + acc[1] * (s8[2] - (extra & 2 ? 0 : 16))
|
||||
+ acc[2] * (s8[1] - (extra & 4 ? 0 : 16)) + acc[3] * (s8[3] - (extra & 8 ? 0 : 16));
|
||||
|
||||
q16 += row_size/2;
|
||||
sc += row_size/2;
|
||||
ex += row_size/2;
|
||||
|
||||
}
|
||||
|
||||
y4 += 4 * QK_K;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row += 2) {
|
||||
float2 tmp = {sumf[row], sumf[row+1]};
|
||||
tmp = simd_sum(tmp);
|
||||
if (tiisg < 2) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = tmp[tiisg];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq2_kl_f32")]]
|
||||
kernel void kernel_mul_mv_iq2_kl_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_kl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
void kernel_mul_mv_iq3_k_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
@@ -8820,6 +8970,28 @@ void dequantize_iq2_ks(device const block_iq2_ks * xb, short il, thread type4x4
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq2_kl(device const block_iq2_kl * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256
|
||||
const short ib32 = il/2;
|
||||
device const uint16_t * ql = (device const uint16_t * )xb->qs + 8*(ib32/2) + 4*(il%2);
|
||||
device const uint16_t * qh = (device const uint16_t * )xb->qh + 4*(il%2);
|
||||
|
||||
half d = (int16_t(((xb->scales_l[ib32%4] >> 4*(ib32/4)) & 0xf) | (((xb->scales_h >> 2*ib32) & 0x3) << 4)) - 32);
|
||||
|
||||
thread uint16_t aux16;
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)&aux16;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
aux16 = ((ql[i] >> 4*(ib32%2)) & 0x0f0f) | (((qh[i] >> ib32) & 0x0101) << 4);
|
||||
constant const int8_t * val1 = (constant const int8_t *)(iq2kl_values + aux8[0]);
|
||||
constant const int8_t * val2 = (constant const int8_t *)(iq2kl_values + aux8[1]);
|
||||
reg[i][0] = d * val1[0];
|
||||
reg[i][1] = d * val1[1];
|
||||
reg[i][2] = d * val2[0];
|
||||
reg[i][3] = d * val2[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256
|
||||
@@ -9596,6 +9768,7 @@ template [[host_name("kernel_get_rows_iq4_ks")]] kernel get_rows_q_t kernel_get
|
||||
template [[host_name("kernel_get_rows_iq5_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>>;
|
||||
template [[host_name("kernel_get_rows_iq4_kss")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
|
||||
template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
|
||||
template [[host_name("kernel_get_rows_iq2_kl")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_kl, half, 16, dequantize_iq2_kl>>;
|
||||
template [[host_name("kernel_get_rows_iq2_kt")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>>;
|
||||
template [[host_name("kernel_get_rows_iq3_kt")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>>;
|
||||
template [[host_name("kernel_get_rows_iq4_kt")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerKT4<float4x4, 16>>;
|
||||
@@ -9644,6 +9817,7 @@ template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_m
|
||||
template [[host_name("kernel_mul_mm_iq5_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq2_kl_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_kl, half, 16, dequantize_iq2_kl>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq2_kt_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq3_kt_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>, float>;
|
||||
template [[host_name("kernel_mul_mm_iq4_kt_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerKT4<half4x4, 16>, float>;
|
||||
@@ -9683,6 +9857,7 @@ template [[host_name("kernel_mul_mm_iq4_ks_f16")]] kernel mat_mm_t kernel_mul_m
|
||||
template [[host_name("kernel_mul_mm_iq5_ks_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq4_kss_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq2_ks_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq2_kl_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_kl, half, 16, dequantize_iq2_kl>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq2_kt_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq3_kt_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>, half>;
|
||||
template [[host_name("kernel_mul_mm_iq4_kt_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerKT4<half4x4, 16>, half>;
|
||||
@@ -9729,6 +9904,7 @@ template [[host_name("kernel_mul_mm_id_iq4_ks_f32")]] kernel mat_mm_id_t kernel
|
||||
template [[host_name("kernel_mul_mm_id_iq5_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_kss_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_kl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_kl, half, 16, dequantize_iq2_kl>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_kt_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerKT4<half4x4, 16>>;
|
||||
@@ -9951,6 +10127,7 @@ template [[host_name("kernel_mul_mv_id_iq5_ks_f32")]] kernel kernel_mul_mv_id_t
|
||||
template [[host_name("kernel_mul_mv_id_iq4_kss_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_kss_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_k_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_k_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_ks_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_kl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_kl_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_kt_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_kt_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq3_kt_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_kt_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_kt_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_kt_f32_impl>>;
|
||||
|
||||
Reference in New Issue
Block a user