mxfp4: Metal

This commit is contained in:
Iwan Kawrakow
2025-08-08 18:17:34 +03:00
parent 19d6799652
commit 7388d9be8d
2 changed files with 184 additions and 2 deletions

View File

@@ -105,6 +105,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS,
@@ -153,6 +154,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32,
@@ -195,6 +197,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32,
@@ -234,6 +237,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32,
@@ -273,6 +277,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16,
@@ -312,6 +317,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32,
@@ -767,6 +773,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS, get_rows_iq3_ks, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS, get_rows_iq4_ks, true);
@@ -815,6 +822,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KS_F32, mul_mv_iq3_ks_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32, mul_mv_iq4_ks_f32, ctx->support_simdgroup_reduction);
@@ -857,6 +865,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KS_F32, mul_mv_id_iq3_ks_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32, mul_mv_id_iq4_ks_f32, ctx->support_simdgroup_reduction);
@@ -896,6 +905,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32, mul_mm_iq3_ks_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32, mul_mm_iq4_ks_f32, ctx->support_simdgroup_mm);
@@ -935,6 +945,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16, mul_mm_iq1_bn_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, mul_mm_iq2_bn_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, mul_mm_iq4_nl_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F16, mul_mm_mxfp4_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, mul_mm_iq4_xs_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16, mul_mm_iq3_ks_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, mul_mm_iq4_ks_f16, ctx->support_simdgroup_mm);
@@ -974,6 +985,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F32, mul_mm_id_mxfp4_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32, mul_mm_id_iq3_ks_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32, mul_mm_id_iq4_ks_f32, ctx->support_simdgroup_mm);
@@ -2192,6 +2204,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32 ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break;
@@ -2236,6 +2249,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16 ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F16 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16 ].pipeline; break;
case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16 ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16 ].pipeline; break;
@@ -2450,6 +2464,12 @@ static void ggml_metal_encode_node(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
} break;
case GGML_TYPE_MXFP4:
{
nth0 = 4;
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
} break;
case GGML_TYPE_IQ4_XS:
{
nth0 = 4;
@@ -2595,7 +2615,7 @@ static void ggml_metal_encode_node(
}
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS) {
src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS || src0t == GGML_TYPE_MXFP4) {
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float)
: src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ5_KS ? 64*sizeof(float) : 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
@@ -2690,6 +2710,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32 ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break;
@@ -2888,6 +2909,12 @@ static void ggml_metal_encode_node(
nth1 = 2;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
} break;
case GGML_TYPE_MXFP4:
{
nth0 = 32;
nth1 = 2;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
} break;
case GGML_TYPE_IQ4_XS:
{
nth0 = 32;
@@ -3044,7 +3071,7 @@ static void ggml_metal_encode_node(
}
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS) {
src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS || src0t == GGML_TYPE_MXFP4) {
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float)
: src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ5_KS ? 64*sizeof(float) : 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
@@ -3095,6 +3122,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS ].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break;

View File

@@ -3975,6 +3975,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
constexpr constant static float kvalues_mxfp4_f[16] = {
0.f, 1.f, 2.f, 3.f, 4.f, 6.f, 8.f, 12.f, 0.f, -1.f, -2.f, -3.f, -4.f, -6.f, -8.f, -12.f
};
constexpr constant static float kvalues_iq4k_f[32] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f,
-123.f, -100.f, -79.f, -61.f, -45.f, -31.f, -18.f, -6.f, 5.f, 17.f, 29.f, 42.f, 57.f, 73.f, 93.f, 117.f,
@@ -6082,6 +6086,104 @@ void kernel_mul_mv_iq4_nl_f32_impl(
}
}
void kernel_mul_mv_mxfp4_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
int64_t ne00,
int64_t ne01,
int64_t ne02,
int64_t ne10,
int64_t ne12,
int64_t ne0,
int64_t ne1,
uint r2,
uint r3,
threadgroup int8_t * shared_values_i8,
uint3 tgpig,
uint tiisg,
uint sgitg) {
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
const int nb = ne00/QK4_NL;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * 2 + sgitg) * 2;
const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
device const block_mxfp4 * x = (device const block_mxfp4 *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
const int ix = tiisg/2; // 0...15
const int it = tiisg%2; // 0 or 1
shared_values[tiisg] = kvalues_mxfp4_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[2]={0.f}, all_sum;
device const float * yb = y + ix * QK4_NL + it * 8;
uint32_t aux32[2];
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
float4 qf1, qf2;
constexpr uint32_t val[2] = { 0x00200000, 0x00400000 };
union { float f; uint32_t u; } helper;
for (int ib = ix; ib < nb; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
device const block_mxfp4 & xb = x[row*nb + ib];
device const uint8_t * q4 = (device const uint8_t *)(xb.qs + 8*it);
float4 acc1 = {0.f}, acc2 = {0.f};
aux32[0] = q4[0] | (q4[1] << 8) | (q4[2] << 16) | (q4[3] << 24);
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
aux32[0] &= 0x0f0f0f0f;
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
acc1 += yl[0] * qf1;
acc2 += yl[1] * qf2;
aux32[0] = q4[4] | (q4[5] << 8) | (q4[6] << 16) | (q4[7] << 24);
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
aux32[0] &= 0x0f0f0f0f;
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
acc1 += yl[2] * qf1;
acc2 += yl[3] * qf2;
acc1 += acc2;
helper.u = xb.e >= 2 ? (uint32_t)(xb.e - 1) << 23u : val[xb.e];
sumf[row] += helper.f * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
}
yb += 16 * QK4_NL;
}
for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
}
}
}
void kernel_mul_mv_iq4_xs_f32_impl(
device const void * src0,
device const float * src1,
@@ -8129,6 +8231,35 @@ kernel void kernel_mul_mv_iq4_nl_f32(
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_mxfp4_f32")]]
kernel void kernel_mul_mv_mxfp4_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_mxfp4_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
device const void * src0,
@@ -8791,6 +8922,24 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
}
}
template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
constexpr uint32_t val[2] = { 0x00200000, 0x00400000 };
device const uint8_t * q4 = (device const uint8_t *)xb->qs;
union { float f; uint32_t u; } helper;
helper.u = xb->e >= 2 ? (uint32_t)(xb->e - 1) << 23u : val[xb->e];
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
for (int i = 0; i < 4; ++i) {
aux32 = q4[4*i] | (q4[4*i+1] << 8) | (q4[4*i+2] << 16) | (q4[4*i+3] << 24);
aux32 = (aux32 >> 4*il) & 0x0f0f0f0f;
reg[i][0] = helper.f * kvalues_mxfp4_f[q8[0]];
reg[i][1] = helper.f * kvalues_mxfp4_f[q8[1]];
reg[i][2] = helper.f * kvalues_mxfp4_f[q8[2]];
reg[i][3] = helper.f * kvalues_mxfp4_f[q8[3]];
}
}
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
@@ -9761,6 +9910,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
template [[host_name("kernel_get_rows_iq2_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_k, QK_NL, dequantize_iq2_k>;
template [[host_name("kernel_get_rows_iq3_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_k, QK_NL, dequantize_iq3_k>;
@@ -9810,6 +9960,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_s, QK_NL, dequantize_iq1_s>, float>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_m, QK_NL, dequantize_iq1_m>, float>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_nl, 2, dequantize_iq4_nl>, float>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_mxfp4, 2, dequantize_mxfp4>, float>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_xs, QK_NL, dequantize_iq4_xs>, float>;
template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_k, QK_NL, dequantize_iq2_k>, float>;
template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_k, QK_NL, dequantize_iq3_k>, float>;
@@ -9850,6 +10001,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_s, QK_NL, dequantize_iq1_s>, half>;
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_m, QK_NL, dequantize_iq1_m>, half>;
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_nl, 2, dequantize_iq4_nl>, half>;
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_mxfp4, 2, dequantize_mxfp4>, half>;
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_xs, QK_NL, dequantize_iq4_xs>, half>;
template [[host_name("kernel_mul_mm_iq2_k_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_k, QK_NL, dequantize_iq2_k>, half>;
template [[host_name("kernel_mul_mm_iq3_k_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_k, QK_NL, dequantize_iq3_k>, half>;
@@ -9897,6 +10049,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq1_s, QK_NL, dequantize_iq1_s>>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq1_m, QK_NL, dequantize_iq1_m>>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_nl, 2, dequantize_iq4_nl>>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_mxfp4, 2, dequantize_mxfp4>>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_xs, QK_NL, dequantize_iq4_xs>>;
template [[host_name("kernel_mul_mm_id_iq2_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_k, QK_NL, dequantize_iq2_k>>;
template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq3_k, QK_NL, dequantize_iq3_k>>;
@@ -10126,6 +10279,7 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq3_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_ks_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_ks_f32_impl>>;