From 7117c23de46f37fe7b8300ec0f6fa6c1ead06e18 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Sat, 9 Aug 2025 08:40:18 +0300
Subject: [PATCH] MXFP4 (#682)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* mxfp4: basics

* mxfp4: Zen4 GEMM

* mxfp4: repacked GEMM (AVX2/Zen4)

* mxfp4: AVX2 GEMM

* mxfp4: NEON GEMM

* mxfp4: repacked GEMM (NEON)

* mxfp4: Metal

* Fix quantized K cache without FA (#680)

* Prevent assert with quantized K cache and no FA

* Fix MMQ when running with quantized K cache without FA

---------

Co-authored-by: Iwan Kawrakow

* Fix for Deepseek r1 parsing (#676)

* Implement function calling / tools for ik_llama.cpp for Kimi K2

* Implement basic tool choice

* Backport llama.cpp tool calls support

* Enhance function calls with improved chat parser and string utilities

  - Add new chat.h/chat.cpp and chat-parser.h/chat-parser.cpp for better chat handling
  - Improve function calls parsing with fallback to llama.cpp builder pattern
  - Add string utility functions (starts_with, ends_with, find_partial_stop)
  - Update README with function calls testing instructions
  - Enhance Kimi K2 parser and function calls documentation
  - Add comprehensive test suite for function calls
  - Update CMakeLists.txt and Makefile for new components

* Enhance function calling with unified streaming and parser improvements

  - Fix streaming content cleanup to prevent function syntax in output
  - Unify content extraction patterns with llama.cpp approach
  - Improve Kimi K2 parser robustness and partial content handling
  - Add comprehensive test coverage for function call scenarios
  - Optimize chat message parsing and diff computation

* Replace hardcoded values in kimi_k2_parser.hpp with named constants

  - Add compile-time constants for all token format markers
  - Add compile-time constants for XML format markers
  - Add compile-time constants for simple format patterns
  - Replace all hardcoded string literals with named constants
  - Use compile-time length calculation to avoid manual counting
  - Improve maintainability and reduce magic numbers throughout parser

* Fix duplicate common_chat_parse definition

  - Remove duplicate implementation from chat-parser.cpp
  - Keep single implementation in chat.cpp following llama.cpp patterns
  - Resolves linker error: multiple definition of common_chat_parse

* Fix JSON assertion failure in function call parsing

  - Add proper validation that 'function' field is an object before accessing nested keys
  - Handle missing 'arguments' field gracefully with default "{}"
  - Prevents crash when parsing malformed tool call JSON structures

* Add comprehensive Qwen3 XML tool calling support with unit tests

  - Implement Qwen3 XML parser with {"name": "func", "arguments": {...}} format
  - Add model detection and routing for Qwen3 vs Kimi-K2 formats
  - Create 8 comprehensive unit tests covering parsing, streaming, error handling
  - Fix token format cleaning bug in kimi_k2_parser.hpp processing order
  - Remove progressive parsing code and related utilities
  - Add tool injection support for Qwen3 format in server utils

* Add DeepSeek R1 function calling support with comprehensive unit tests

  - Implement complete DeepSeek R1 tool call parsing in common_chat_parser.cpp
  - Add DeepSeek R1 model detection and tool injection in deepseek_r1_tools.hpp
  - Update function_calls.hpp with DeepSeek R1 integration and content extraction
  - Update documentation to reflect support for Kimi-K2, Qwen3, and DeepSeek R1 models
  - Add comprehensive unit tests for DeepSeek R1 reasoning, tool calls, and integration
  - Port exact implementation patterns from original llama.cpp for compatibility

  Key features:
  - Native DeepSeek R1 format: <|tool▁calls▁begin|>function<|tool▁sep|>name```json{}```<|tool▁call▁end|><|tool▁calls▁end|>
  - Reasoning content extraction from <think>...</think> tags
  - Multiple tool calls support with separate call blocks
  - Model detection for deepseek-r1, deepseek_r1 naming patterns
  - Integration with incremental parsing and streaming support

* Add partial parsing support for JSON and regex

  - json-partial.h/cpp: JSON partial parsing functionality
  - regex-partial.h/cpp: Regex partial parsing functionality

* Add format_chat integration tests for Qwen3 tool injection

  - Add test_qwen3_format_chat_integration() to validate tool injection pipeline
  - Test tool injection conditions and system message enhancement
  - Verify JSON formatting and anti-preamble instructions
  - Add comprehensive test documentation

  Tests confirm tool injection works correctly - the conversational preamble issue
  is not in ik_llama.cpp but likely in UI configuration.

* Fix Qwen3 tool call parsing - pass model name to parser

  Server was not passing the model name to parse_chat_message_incremental(),
  causing Qwen3 to fall back to the Kimi-K2 parser and return tool calls as
  content instead of a proper tool_calls array.

* Fix non-streaming path to use model-specific parsing

  Non-streaming responses were hardcoded to use the Kimi-K2 format, causing
  Qwen3 XML tool calls to be returned as content instead of a proper
  tool_calls array. Now uses the same model detection as the streaming path
  for consistency.

* Update Qwen3 function call handling in server and tests

  - Enhanced server function call detection and response formatting
  - Improved test coverage for Qwen3 tool call scenarios
  - Refined XML parsing for better tool execution support

* Add DeepSeek-R1 function call parsing support

  Implements comprehensive parsing for all 4 DeepSeek-R1 function call formats:

  - Format 1: Standard function call syntax (already supported)
  - Format 2: Alternative function call patterns (already supported)
  - Format 3: Tools array format - function\n```json\n{"tools": [...]}
  - Format 4: XML wrapped format - functionName\n```json\n{...}```

  Key changes:
  - Added parse_deepseek_r1_tools_array() following original parse_prefixed_json_tool_call_array pattern
  - Added parse_deepseek_r1_xml_wrapped() following Hermes-2-Pro XML wrapper patterns
  - Integrated both parsers into exception handling chain for robust fallback
  - Added comprehensive TDD test coverage for all formats
  - Anonymized all confidential information while preserving functionality

  Resolves tool_calls_count=0 issue where DeepSeek-R1 models generated valid
  tool calls but the server failed to parse them correctly.

* Update function_calls.md documentation for DeepSeek-R1 Format 4

  - Added Format 4 (XML wrapped) documentation with examples
  - Updated implementation notes with correct parser order (3→4→1→2)
  - Marked all DeepSeek-R1 formats as working (July 2025 update)
  - Updated test status for Format 3 and 4 as passing
  - Added parse_deepseek_r1_xml_wrapped() function reference
  - Corrected implementation file line numbers

* Fix merge conflict in test-function-calls.cpp

  - Removed incomplete merge conflict marker from line 3027
  - Ensured all tests compile and pass successfully
  - All DeepSeek-R1 formats (1-4) working correctly
  - All streaming and content cleaning tests passing

* Fix DeepSeek R1 parsing issue with responses wrapped in think tags

  Restore missing consume_rest() call from working PR #648 implementation.
  When responses don't contain tool calls, remaining content after reasoning
  parsing must be preserved as displayable content. Fixes issue where entire
  responses wrapped in <think> tags resulted in empty content output.

* Implement proper reasoning handling following original llama.cpp patterns

  - Add missing reasoning_format and reasoning_in_content fields to common_chat_syntax
  - Update try_parse_reasoning to match original llama.cpp logic exactly
  - Add TDD test case with reasoning_in_content=true for DeepSeek R1
  - Following TDD: test should now pass with proper syntax configuration

  Based on original llama.cpp implementation patterns.

* TDD SUCCESS: Fix DeepSeek R1 thinking tag termination issue

  ✅ Test passes with reasoning_in_content=true configuration
  - Content properly preserved: 'content' displays fully
  - Reasoning field empty as expected
  - Following TDD: test-first approach validates the fix

  Next: Update server to automatically apply this configuration.

* Complete server integration fix for DeepSeek R1 thinking tag termination

  - Server now automatically sets reasoning_in_content=true for DeepSeek R1 models
  - Fixes issue where responses wrapped in <think> tags appear empty to users

* Add TDD test case for DeepSeek R1 thinking tag termination issue

  - Test reproduces the exact failure scenario reported by user
  - Validates that reasoning_in_content=true fixes the issue
  - Demonstrates empty content problem and working solution

* Add remaining TDD test changes for DeepSeek R1 thinking tag fix

* Add debug output after upstream merge

* Remove temporary benchmark and debug files

  - Remove tests/benchmark-progressive-parsing.cpp (development tool, not part of core functionality)
  - Remove tests/reproduce_bug.sh (debugging script, not needed for PR)

* Port cpu moe options from mainline (#672)

* Port cpu moe options from mainline

* Use strdup and int32_t to follow coding guidelines

* mxfp4: CUDA dequantize

* mxfp4: CUDA GEMV

* mxfp4: CUDA MMQ

* mxfp4: minor CUDA tweaks

---------

Co-authored-by: Iwan Kawrakow
Co-authored-by: Anton Sokolchenko
Co-authored-by: Parsa <61601745+TheLegendOfKitty@users.noreply.github.com>
---
 examples/quantize/quantize.cpp                |   1 +
 ggml/include/ggml.h                           |   8 +-
 ggml/src/ggml-common.h                        |  18 ++
 ggml/src/ggml-cuda.cu                         |   1 +
 ggml/src/ggml-cuda/common.cuh                 |   7 +
 ggml/src/ggml-cuda/convert.cu                 |  32 +++
 ggml/src/ggml-cuda/mmq.cu                     |   4 +
 ggml/src/ggml-cuda/mmq.cuh                    |  72 +++++++
 ggml/src/ggml-cuda/mmvq.cu                    |  14 ++
 .../template-instances/mmq-instance-iq4_nl.cu |   3 +-
 ggml/src/ggml-cuda/vecdotq.cuh                |  35 ++++
 ggml/src/ggml-impl.h                          |  18 ++
 ggml/src/ggml-metal.m                         |  32 ++-
 ggml/src/ggml-metal.metal                     | 154 +++++++++++++++
 ggml/src/ggml-quants.c                        |   1 +
 ggml/src/ggml.c                               |  30 ++-
 ggml/src/iqk/iqk_gemm_legacy_quants.cpp       | 183 +++++++++++++++---
 ggml/src/iqk/iqk_mul_mat.cpp                  |   7 +
 ggml/src/iqk/iqk_quantize.cpp                 | 141 ++++++++++++++
 ggml/src/iqk/iqk_quantize.h                   |   6 +
 include/llama.h                               |   1 +
 src/llama.cpp                                 |   3 +
 22 files changed, 733 insertions(+), 38 deletions(-)

diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 3de7bc20..2e2c62bf 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -28,6 +28,7 @@ static const std::vector QUANT_OPTIONS = {
     { "Q5_0",    LLAMA_FTYPE_MOSTLY_Q5_0,    " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", },
     { "Q5_1",    LLAMA_FTYPE_MOSTLY_Q5_1,    " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
     { "Q6_0",    LLAMA_FTYPE_MOSTLY_Q6_0,    " 6.5 bpw quantization", },
+    { "MXFP4",   LLAMA_FTYPE_MOSTLY_MXFP4,   " 4.25 bpw 4-bit float quantization",},
     { "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS, " 2.06 bpw quantization", },
     { "IQ2_XXS_R4",LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4,"IQ2_XXS repacked", },
     { "IQ2_XS",  LLAMA_FTYPE_MOSTLY_IQ2_XS,  " 2.31 bpw quantization", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 5b90c9a5..2c261392 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -403,6 +403,7 @@ extern "C" {
         GGML_TYPE_Q4_0_4_4 = 31,
         GGML_TYPE_Q4_0_4_8 = 32,
         GGML_TYPE_Q4_0_8_8 = 33,
+        GGML_TYPE_MXFP4    = 39, // so we are compatible with mainline
         //
         // So we are able to consume MS BitNet I2_S quants
         //
@@ -507,9 +508,10 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_XS   = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M    = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16     = 24, // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
+        GGML_FTYPE_MOSTLY_MXFP4    = 25, // except 1d tensors, using 26 to be compatible with mainline
+        GGML_FTYPE_MOSTLY_Q4_0_4_4 = 26, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_0_4_8 = 27, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_0_8_8 = 28, // except 1d tensors
         //
         GGML_FTYPE_MOSTLY_Q6_0     = 127, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_BN   = 128, // except 1d tensors
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 1dc1ff6e..59f7ae71 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -158,6 +158,9 @@ typedef sycl::half2 ggml_half2;
 #define QI1_BN (QK_IQ1BN / (4*QR1_BN))
 #define QR1_BN 8
 
+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
 
 #define QK4_0 32
@@ -174,6 +177,15 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
+// This is unfortunate (block is 17 bytes, so not even a 2-byte alignment)
+// But to be able to use MXFP4-quantized models from mainline, we do the same.
+#define QK_MXFP4 32 +typedef struct { + uint8_t e; // E8M0 + uint8_t qs[QK_MXFP4/2]; +} block_mxfp4; +static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding"); + #define QK5_0 32 typedef struct { ggml_half d; // delta @@ -2211,5 +2223,11 @@ GGML_TABLE_BEGIN(int8_t, iq6nl_values, 128) 48, 52, 56, 60, 64, 69, 73, 78, 83, 88, 93, 99, 104, 110, 116, 122, GGML_TABLE_END() +// e2m1 values (doubled) +// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf +GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16) + 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12, +GGML_TABLE_END() + #endif // GGML_COMMON_IMPL #endif // GGML_COMMON_IMPL diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 67d9828c..9372c05c 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3498,6 +3498,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_KL: case GGML_TYPE_IQ3_KS: diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 15485f60..c856a44b 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -550,6 +550,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI4_NL; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_NL; + static constexpr int qr = QR4_NL; + static constexpr int qi = QI4_NL; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 8c03ae1b..689613f5 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -736,6 +736,27 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst } } +template +static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 }; + const int64_t i = blockIdx.x; + const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK4_NL); + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[ib].qs + 4*il; + union { float f; uint32_t u; } helper; + helper.u = x[ib].e >= 2 ? 
uint32_t(x[ib].e - 1) << 23u : uval[x[ib].e]; + const float d = helper.f; + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]; + y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]; + } +} + template static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int64_t i = blockIdx.x; @@ -1611,6 +1632,13 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq4_nl<<>>(vx, y); } +template +static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_mxfp4<<>>(vx, y); +} + template static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -1943,6 +1971,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq2_bn_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ4_KS: @@ -2044,6 +2074,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq2_bn_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ4_KS: diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 1e3accf0..bebc7c87 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -94,6 +94,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_NL: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_MXFP4: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ2_KL: mul_mat_q_case(ctx, args, stream); break; @@ -210,6 +213,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ2_KL: case GGML_TYPE_IQ3_KS: case GGML_TYPE_IQ4_KSS: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 20277041..9adc94a6 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -84,6 +84,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: @@ -204,6 +205,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ1_S_R4: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_MXFP4 : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_KL : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ3_KS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KSS : return MMQ_DP4A_TXS_Q8_0; @@ -263,6 +265,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ1_S_R4: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_MXFP4 : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_KL : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ3_KS : return MMQ_MMA_TILE_X_K_Q8_0; case 
GGML_TYPE_IQ4_KSS : return MMQ_MMA_TILE_X_K_Q8_0; @@ -2078,6 +2081,67 @@ template static __device__ __forceinlin } } +template static __device__ __forceinline__ void load_tiles_mxfp4( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kbx = threadIdx.x / QI4_NL; + const int kqsx = threadIdx.x % QI4_NL; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbx; + + const int aux_q4 = get_int_b1(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); + const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; +#endif // INT8_MMA_AVAILABLE + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + + union { float f; uint32_t u; } helper; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) { + int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbxd; + helper.u = bxi->e ? 
uint32_t(bxi->e) << 23u : 0x00400000; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.5f * helper.f; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = 0.5f * helper.f; +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void load_tiles_iq2_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -3624,6 +3688,13 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; @@ -4164,6 +4235,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KS); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 012b3e5e..10d16aeb 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -31,6 +31,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) case GGML_TYPE_IQ1_S : return vec_dot_iq1_s_q8_1; case GGML_TYPE_IQ1_M : return vec_dot_iq1_m_q8_1; case GGML_TYPE_IQ4_NL : return vec_dot_iq4_nl_q8_1; + case GGML_TYPE_MXFP4 : return vec_dot_mxfp4_q8_1; case GGML_TYPE_IQ4_XS : return vec_dot_iq4_xs_q8_1; case GGML_TYPE_IQ3_S : return vec_dot_iq3_s_q8_1; default : return nullptr; @@ -56,6 +57,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { case GGML_TYPE_IQ3_XXS : return VDR_IQ3_XXS_Q8_1_MMVQ; case GGML_TYPE_IQ3_S : return VDR_IQ3_S_Q8_1_MMVQ; case GGML_TYPE_IQ4_NL : return VDR_IQ4_NL_Q8_1_MMVQ; + case GGML_TYPE_MXFP4 : return VDR_MXFP4_Q8_1_MMVQ; case GGML_TYPE_IQ4_XS : return VDR_IQ4_XS_Q8_1_MMVQ; default : return 1; } @@ -417,6 +419,14 @@ static void mul_mat_vec_iq4_nl_q8_1_cuda( mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +static void mul_mat_vec_mxfp4_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + static void mul_mat_vec_iq4_xs_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, @@ -509,6 +519,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ4_NL: mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; + case GGML_TYPE_MXFP4: + mul_mat_vec_mxfp4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, 
stream); + break; case GGML_TYPE_IQ4_XS: mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; @@ -686,6 +699,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KL: diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu index eb02fab0..c88946c2 100644 --- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu @@ -1,5 +1,4 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - #include "../mmq.cuh" DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +DECL_MMQ_CASE(GGML_TYPE_MXFP4); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index cae5e04f..97a792bd 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -17,6 +17,15 @@ static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32 return x32; } +static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) { + const uint8_t * x8 = (const uint8_t *)x; + + int x32 = x8[4*i32 + 0] | (x8[4*i32 + 1] << 8); + x32 |= (x8[4*i32 + 2] | (x8[4*i32 + 3] << 8)) << 16; + + return x32; +} + static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) { return ((const int *) x)[i32]; // assume at least 4 byte alignment } @@ -1167,6 +1176,32 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( return d * sumi; } +#define VDR_MXFP4_Q8_1_MMVQ 2 +#define VDR_MXFP4_Q8_1_MMQ 4 + +static __device__ __forceinline__ float vec_dot_mxfp4_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx; + + const int * q8 = (const int *) bq8_1->qs + iqs; + + int2 sumi = {0, 0}; +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) { + const int aux_q4 = get_int_b1(bq4->qs, iqs + l); + const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); + + sumi.x = ggml_cuda_dp4a(v.x, q8[l + 0], sumi.x); + sumi.y = ggml_cuda_dp4a(v.y, q8[l + 4], sumi.y); + } + + union { float f; uint32_t u; } helper; + helper.u = bq4->e ? uint32_t(bq4->e) << 23u : 0x00400000; + + return 0.5f * helper.f * __low2float(bq8_1->ds) * (sumi.x + sumi.y); +} + #define VDR_IQ4_XS_Q8_1_MMVQ 4 #define VDR_IQ4_XS_Q8_1_MMQ 4 diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index e4e36860..62f07e1e 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -29,6 +29,24 @@ #endif +// Does not handle NaN +static inline float ggml_e8m0_to_fp32(uint8_t x) { + union { float f; uint32_t u; } helper; + helper.u = x ? (uint32_t)x << 23u : 0x00400000; + return helper.f; +} + +// As above, but returns ggml_e8m0_to_fp32(x)/2 +static inline float ggml_e8m0_to_fp32_half(uint8_t x) { + static uint32_t val[2] = { 0x00200000, 0x00400000 }; + union { float f; uint32_t u; } helper; + helper.u = x >= 2 ? (uint32_t)(x - 1) << 23u : val[x]; + return helper.f; +} + +#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x) +#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x) + /** * Converts brain16 to float32. 
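
As a sanity check on the E8M0 helpers added to ggml-impl.h above, the following standalone sketch (not part of the patch) mirrors the bit manipulation and checks a few hand-verifiable values, assuming IEEE-754 binary32 floats. The halving at the end is why the MXFP4 kernels either use GGML_E8M0_TO_FP32_HALF or multiply by 0.5f after looking values up in the doubled kvalues_mxfp4 table.

```c
// Worked example: decoding E8M0 scales the way ggml_e8m0_to_fp32 does.
#include <assert.h>
#include <stdint.h>

static float e8m0_to_fp32(uint8_t x) {          // mirrors ggml_e8m0_to_fp32
    union { float f; uint32_t u; } h;
    h.u = x ? (uint32_t)x << 23 : 0x00400000u;  // x == 0 -> 2^-127 (subnormal)
    return h.f;
}

int main(void) {
    assert(e8m0_to_fp32(127) == 1.0f);          // 2^(127-127)
    assert(e8m0_to_fp32(128) == 2.0f);          // 2^(128-127)
    assert(e8m0_to_fp32(126) == 0.5f);          // 2^(126-127)
    assert(e8m0_to_fp32(0)   == 0x1p-127f);     // smallest scale, subnormal
    // The *_HALF variant used by the MXFP4 code paths is simply half of this,
    // compensating for kvalues_mxfp4 storing the E2M1 values doubled.
    assert(0.5f * e8m0_to_fp32(127) == 0.5f);
    return 0;
}
```
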
* diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index a86c66b6..104ad664 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -105,6 +105,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, + GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS, @@ -153,6 +154,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32, @@ -195,6 +197,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32, @@ -234,6 +237,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32, @@ -273,6 +277,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, @@ -312,6 +317,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32, @@ -767,6 +773,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS, get_rows_iq3_ks, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS, get_rows_iq4_ks, true); @@ -815,6 +822,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, ctx->support_simdgroup_reduction); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KS_F32, mul_mv_iq3_ks_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32, mul_mv_iq4_ks_f32, ctx->support_simdgroup_reduction); @@ -857,6 +865,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KS_F32, mul_mv_id_iq3_ks_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32, mul_mv_id_iq4_ks_f32, ctx->support_simdgroup_reduction); @@ -896,6 +905,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32, mul_mm_iq3_ks_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32, mul_mm_iq4_ks_f32, ctx->support_simdgroup_mm); @@ -935,6 +945,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16, mul_mm_iq1_bn_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, mul_mm_iq2_bn_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, mul_mm_iq4_nl_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F16, mul_mm_mxfp4_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, mul_mm_iq4_xs_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16, mul_mm_iq3_ks_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, mul_mm_iq4_ks_f16, ctx->support_simdgroup_mm); @@ -974,6 +985,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F32, mul_mm_id_mxfp4_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32, mul_mm_id_iq3_ks_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32, mul_mm_id_iq4_ks_f32, ctx->support_simdgroup_mm); @@ -2192,6 +2204,7 @@ static void ggml_metal_encode_node( case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break; case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32 ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break; @@ -2236,6 +2249,7 @@ static void ggml_metal_encode_node( case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16 ].pipeline; break; case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16 ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F16 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16 ].pipeline; break; case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16 ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16 ].pipeline; break; @@ -2450,6 +2464,12 @@ static void ggml_metal_encode_node( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; } break; + case GGML_TYPE_MXFP4: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline; + } break; case GGML_TYPE_IQ4_XS: { nth0 = 4; @@ -2595,7 +2615,7 @@ static void ggml_metal_encode_node( } else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| - src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS) { + src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS || src0t == GGML_TYPE_MXFP4) { const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ5_KS ? 
64*sizeof(float) : 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; @@ -2690,6 +2710,7 @@ static void ggml_metal_encode_node( case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32 ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break; @@ -2888,6 +2909,12 @@ static void ggml_metal_encode_node( nth1 = 2; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; } break; + case GGML_TYPE_MXFP4: + { + nth0 = 32; + nth1 = 2; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline; + } break; case GGML_TYPE_IQ4_XS: { nth0 = 32; @@ -3044,7 +3071,7 @@ static void ggml_metal_encode_node( } else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| - src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS) { + src0t == GGML_TYPE_IQ4_KSS || src0t == GGML_TYPE_IQ5_KS || src0t == GGML_TYPE_MXFP4) { const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ5_KS ? 64*sizeof(float) : 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; @@ -3095,6 +3122,7 @@ static void ggml_metal_encode_node( case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 53de59dd..f700e6f7 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3975,6 +3975,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = { -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f }; +constexpr constant static float kvalues_mxfp4_f[16] = { + 0.f, 1.f, 2.f, 3.f, 4.f, 6.f, 8.f, 12.f, 0.f, -1.f, -2.f, -3.f, -4.f, -6.f, -8.f, -12.f +}; + constexpr constant static float kvalues_iq4k_f[32] = { -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f, -123.f, -100.f, -79.f, -61.f, -45.f, -31.f, -18.f, -6.f, 5.f, 17.f, 29.f, 42.f, 57.f, 73.f, 93.f, 117.f, @@ -6082,6 +6086,104 @@ void kernel_mul_mv_iq4_nl_f32_impl( } } +void kernel_mul_mv_mxfp4_f32_impl( + device const void * 
src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK4_NL; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_mxfp4 * x = (device const block_mxfp4 *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0 or 1 + + shared_values[tiisg] = kvalues_mxfp4_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK4_NL + it * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + constexpr uint32_t val[2] = { 0x00200000, 0x00400000 }; + union { float f; uint32_t u; } helper; + + for (int ib = ix; ib < nb; ib += 16) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + + device const block_mxfp4 & xb = x[row*nb + ib]; + device const uint8_t * q4 = (device const uint8_t *)(xb.qs + 8*it); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] | (q4[1] << 8) | (q4[2] << 16) | (q4[3] << 24); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[4] | (q4[5] << 8) | (q4[6] << 16) | (q4[7] << 24); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + helper.u = xb.e >= 2 ? 
(uint32_t)(xb.e - 1) << 23u : val[xb.e]; + sumf[row] += helper.f * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 16 * QK4_NL; + } + + for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + void kernel_mul_mv_iq4_xs_f32_impl( device const void * src0, device const float * src1, @@ -8129,6 +8231,35 @@ kernel void kernel_mul_mv_iq4_nl_f32( kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } +[[host_name("kernel_mul_mv_mxfp4_f32")]] +kernel void kernel_mul_mv_mxfp4_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_mxfp4_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( device const void * src0, @@ -8791,6 +8922,24 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 } } +template +void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) { + constexpr uint32_t val[2] = { 0x00200000, 0x00400000 }; + device const uint8_t * q4 = (device const uint8_t *)xb->qs; + union { float f; uint32_t u; } helper; + helper.u = xb->e >= 2 ? 
(uint32_t)(xb->e - 1) << 23u : val[xb->e]; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = q4[4*i] | (q4[4*i+1] << 8) | (q4[4*i+2] << 16) | (q4[4*i+3] << 24); + aux32 = (aux32 >> 4*il) & 0x0f0f0f0f; + reg[i][0] = helper.f * kvalues_mxfp4_f[q8[0]]; + reg[i][1] = helper.f * kvalues_mxfp4_f[q8[1]]; + reg[i][2] = helper.f * kvalues_mxfp4_f[q8[2]]; + reg[i][3] = helper.f * kvalues_mxfp4_f[q8[3]]; + } +} + template void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 @@ -9761,6 +9910,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq2_k")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq3_k")]] kernel get_rows_q_t kernel_get_rows_q; @@ -9810,6 +9960,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm, float>; template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm, float>; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mat_mm_t kernel_mul_mm, float>; template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm, float>; template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; @@ -9850,6 +10001,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mat_mm_t kernel_mul_mm, half>; template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mat_mm_t kernel_mul_mm, half>; template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mat_mm_t kernel_mul_mm, half>; template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mat_mm_t kernel_mul_mm, half>; template [[host_name("kernel_mul_mm_iq2_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; template [[host_name("kernel_mul_mm_iq3_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; @@ -9897,6 +10049,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; +template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_iq2_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; @@ -10126,6 
+10279,7 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq3_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_ks_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e49417af..7a14fcf2 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15418,6 +15418,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_MXFP4: break; case GGML_TYPE_Q6_0: break; case GGML_TYPE_IQ2_K: break; case GGML_TYPE_IQ2_KS: break; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5aec6b0d..f3a23727 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1301,14 +1301,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq4_nl, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref, .vec_dot = ggml_vec_dot_iq4_nl_q8_0, -#if GGML_USE_IQK_MULMAT #if defined HAVE_FANCY_SIMD .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, -#endif -#else - .vec_dot_type = GGML_TYPE_Q8_0, #endif .nrows = 1, .row_meta_size = 0, @@ -1326,6 +1322,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_MXFP4] = { + .type_name = "mxfp4", + .blck_size = QK_MXFP4, + .type_size = sizeof(block_mxfp4), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_mxfp4, + .from_float = quantize_row_mxfp4, + .from_float_ref = (ggml_from_float_t)quantize_row_mxfp4_ref, + .vec_dot = vec_dot_mxfp4_q8_0_x4, +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ4_KS] = { .type_name = "iq4_ks", .blck_size = QK_K, @@ -4609,6 +4622,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q6_0_R4: wtype = GGML_TYPE_Q6_0_R4; break; case GGML_FTYPE_MOSTLY_Q8_0_R8: wtype = GGML_TYPE_Q8_0_R8; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; + case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break; case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break; case GGML_FTYPE_MOSTLY_IQ4_KS_R4: wtype = GGML_TYPE_IQ4_KS_R4;break; case GGML_FTYPE_MOSTLY_IQ5_KS_R4: wtype = GGML_TYPE_IQ5_KS_R4;break; @@ -11388,6 +11402,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: case GGML_TYPE_Q8_0_R8: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -11868,6 +11883,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: case GGML_TYPE_Q8_0_R8: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -12045,6 +12061,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q6_0_R4: case 
GGML_TYPE_I2_S: case GGML_TYPE_Q8_0_R8: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -15549,6 +15566,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: case GGML_TYPE_Q8_0_R8: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -15966,6 +15984,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: case GGML_TYPE_Q8_0_R8: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -16289,6 +16308,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: case GGML_TYPE_Q8_0_R8: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -16929,6 +16949,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q6_0_R4: case GGML_TYPE_I2_S: case GGML_TYPE_Q8_0_R8: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: @@ -24005,6 +24026,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q5_0_R4: result = quantize_q5_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_0_R4: result = quantize_q6_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0_R8: result = quantize_q8_0_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KS_R4:result = quantize_iq4_ks_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp index ab6eb130..03128319 100644 --- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp @@ -1,4 +1,5 @@ #include "iqk_gemm_legacy_quants.h" +#include #ifdef IQK_IMPLEMENT @@ -105,6 +106,21 @@ struct ScaleHelperQ_0 { template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); } }; +struct ScaleHelperQ_0_MXFP4 { + float scales[4]; + template + inline __m128 prepare4(const Q * y) { + for (int j = 0; j < 4; ++j) scales[j] = GGML_E8M0_TO_FP32_HALF(y[j].e); + return _mm_loadu_ps(scales); + } + template + inline __m128 prepare4(__m128 other_scales, const Q * y) { + return _mm_mul_ps(other_scales, prepare4(y)); + } + template inline float prepare1(const Q * y) const { return GGML_E8M0_TO_FP32_HALF(y->e); } + template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); } +}; + template struct ScaleHelperQ_0_1 { ggml_half scales8[4]; @@ -128,28 +144,28 @@ struct ScaleHelperQ_0_1 { const __m128 min = _mm_set1_ps(float(-min_value)); }; -//template -//struct ScaleHelperQ_0_2 { -// ggml_bf16_t scales8[4]; -// template -// inline __m256 prepare4(const Q * y) { -// for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; -// auto s4 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales8)), 16)); -// return _mm256_set_m128(_mm_mul_ps(s4, min), s4); -// } -// template -// inline __m256 
prepare4(__m256 other_scales, const Q * y) { -// return _mm_mul256_ps(other_scales, prepare4(y)); -// } -// template inline std::pair prepare1(const Q * y) const { -// float d = GGML_BF16_TO_FP32(y->d); -// return std::make_pair(d, -d*float(min_value)); -// } -// std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { -// return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); -// } -// const __m128 min = _mm_set1_ps(float(-min_value)); -//}; +template +struct ScaleHelperQ_0_1_MXFP4 { + float scales[4]; + template + inline __m256 prepare4(const Q * y) { + for (int j = 0; j < 4; ++j) scales[j] = GGML_E8M0_TO_FP32_HALF(y[j].e); + auto s4 = _mm_loadu_ps(scales); + return _mm256_set_m128(_mm_mul_ps(s4, min), s4); + } + template + inline __m256 prepare4(__m256 other_scales, const Q * y) { + return _mm_mul256_ps(other_scales, prepare4(y)); + } + template inline std::pair prepare1(const Q * y) const { + float d = GGML_E8M0_TO_FP32_HALF(y->e); + return std::make_pair(d, -d*float(min_value)); + } + std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { + return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); + } + const __m128 min = _mm_set1_ps(float(-min_value)); +}; struct ScaleHelperQ8_1 { template @@ -553,6 +569,49 @@ struct IQ4_NL0_Dequantizer { } }; +//============================= +static inline __m128i load_unsigned_mxfp4_values_128() { + static const uint8_t kvalues_mxfp4_unsigned[16] = {12, 13, 14, 15, 16, 18, 20, 24, 12, 11, 10, 9, 8, 6, 4, 0}; + return _mm_loadu_si128((const __m128i *)kvalues_mxfp4_unsigned); +} + +static inline __m256i load_unsigned_mxfp4_values_256() { + auto val128 = load_unsigned_mxfp4_values_128(); + return MM256_SET_M128I(val128, val128); +} + +#ifdef HAVE_FANCY_SIMD +static inline __m512i load_unsigned_mxfp4_values_512() { + auto val256 = load_unsigned_mxfp4_values_256(); + return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); +} +#endif + +static inline __m128i load_mxfp4_values_128() { + return _mm_loadu_si128((const __m128i *)kvalues_mxfp4); +} + +static inline __m256i load_mxfp4_values_256() { + auto val128 = load_mxfp4_values_128(); + return MM256_SET_M128I(val128, val128); +} + +struct MXFP4_Dequantizer { + Dequantizer4bit b4; + const __m256i values = load_unsigned_mxfp4_values_256(); + inline __m256i dequant(const block_mxfp4 * x) const { + return _mm256_shuffle_epi8(values, b4.dequant(x->qs)); + } +}; + +struct MXFP40_Dequantizer { + Dequantizer4bit b4; + const __m256i values = load_mxfp4_values_256(); + inline __m256i dequant(const block_mxfp4 * x) const { + return _mm256_shuffle_epi8(values, b4.dequant(x->qs)); + } +}; + struct Q4_1_Dequantizer { Dequantizer4bit b4; inline __m256i dequant(const block_q4_1 * x) const { @@ -665,6 +724,11 @@ struct Q4_0_1_Unpacker final : public Q_Unpacker using Sum4T = Sum4q4; inline static int block_size() { return QK4_0; } }; +struct MXFP4_Unpacker final : public Q_Unpacker, MXFP4_Dequantizer> { + MXFP4_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ82; + inline static int block_size() { return QK4_NL; } +}; #ifdef HAVE_FANCY_SIMD struct IQ4_NL_Unpacker final : public Q_Unpacker, IQ4_NL_Dequantizer> { IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} @@ -672,7 +736,7 @@ struct IQ4_NL_Unpacker final : public Q_Unpacker { +struct IQ4_NL_Unpacker final : public Q_Unpacker { IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) 
     using Sum4T = Sum4TypeQ80;
     inline static int block_size() { return QK4_NL; }
@@ -1757,7 +1821,11 @@ void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc
     for (int i = 0; i < nb; ++i) {
         for (int k = 0; k < 8; ++k) {
-            y[i].d[k] = x8[k][i].d;
+            if constexpr (std::is_same_v) {
+                y[i].d[k] = GGML_FP32_TO_FP16(GGML_E8M0_TO_FP32_HALF(x8[k][i].e));
+            } else {
+                y[i].d[k] = x8[k][i].d;
+            }
             _mm256_storeu_si256((__m256i *)block, deq.dequant(x8[k] + i));
             auto qs = (uint32_t *)y[i].qs;
             for (int l = 0; l < 4; ++l) {
@@ -1819,7 +1887,8 @@ template void set_functions(std::array
  || std::is_same_v ||
-              std::is_same_v || std::is_same_v) {
+              std::is_same_v || std::is_same_v ||
+              std::is_same_v) {
         IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
     }
 }
@@ -1835,6 +1904,7 @@ bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx
         case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q8_0 : iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_MXFP4 : iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
         default: return false;
     }
     return true;
@@ -1878,6 +1948,12 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array
            (kernels);
+//#ifndef HAVE_FANCY_SIMD
+//            expected_typeB = GGML_TYPE_Q8_0_X4;
+//#endif
+            break;
         case GGML_TYPE_Q4_0_R8:
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_0_r8_q8_2, kernels)
 #ifdef HAVE_FANCY_SIMD
@@ -2039,7 +2115,7 @@ template struct Q80 {
     template <typename Dequantizer>
     inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
         deq.prepare1(i);
-        float d = GGML_FP16_TO_FP32(deq.x[i].d);
+        float d = deq.block_scale(i);
         for (int iy = 0; iy < nrc; ++iy) {
             auto q8b = vld1q_s8_x2(y[iy][i].qs);
             auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);
@@ -2147,6 +2223,8 @@ struct DequantizerQ40 final : public BaseLegacyDequantizer {
         return vld1_f16((const float16_t *)aux);
     }
+    inline float block_scale(int i) const { return GGML_FP16_TO_FP32(x[i].d); }
+
     const int8x16_t m8 = vdupq_n_s8(-8);
     //ggml_half aux[4];
 };
@@ -2174,6 +2252,7 @@ struct DequantizerQ60 final : public BaseLegacyDequantizer {
         }
         return vld1_f16((const float16_t *)aux);
     }
+    inline float block_scale(int i) const { return GGML_FP16_TO_FP32(x[i].d); }
 
     const int8x16_t m32 = vdupq_n_s8(-32);
     const uint8x16_t hmask = vdupq_n_u8(0x30);
@@ -2204,6 +2283,36 @@ struct DequantizerIQ4NL final : public BaseLegacyDequantizer {
         static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
         return vld1q_s8(iq4nl_values);
     }
+    inline float block_scale(int i) const { return GGML_FP16_TO_FP32(x[i].d); }
+
+    const int8x16_t values = load_values();
+};
+
+struct DequantizerMXFP4 final : public BaseLegacyDequantizer {
+
+    DequantizerMXFP4(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+    inline void prepare1(int i, int8x16_t * q) const {
+        bits.prepare1(x[i].qs, q);
+        q[0] = vqtbl1q_s8(values, q[0]);
+        q[1] = vqtbl1q_s8(values, q[1]);
+    }
+    inline void prepare1(int i) {
+        prepare1(i, bits.b);
+    }
+
+    inline float16x4_t new_block(int i) {
+        float aux[4];
+        for (int k = 0; k < 4; ++k) {
+            aux[k] = GGML_E8M0_TO_FP32_HALF(x[4*i+k].e);
+            prepare1(4*i+k, bits.b + 2*k);
+        }
+        return vcvt_f16_f32(vld1q_f32(aux));
+    }
+    static int8x16_t load_values() {
+        return vld1q_s8(kvalues_mxfp4);
+    }
+    inline float block_scale(int i) const { return GGML_E8M0_TO_FP32_HALF(x[i].e); }
 
     const int8x16_t values = load_values();
 };
@@ -2280,6 +2389,7 @@ struct DequantizerQ50 final : public BaseLegacyDequantizer {
         }
         return vld1_f16((const float16_t *)aux);
     }
+    inline float block_scale(int i) const { return GGML_FP16_TO_FP32(x[i].d); }
 
     HighBit5Legacy hbits;
@@ -2305,6 +2415,7 @@ struct DequantizerQ80 final : public BaseLegacyDequantizer {
         }
         return vld1_f16((const float16_t *)aux);
     }
+    inline float block_scale(int i) const { return GGML_FP16_TO_FP32(x[i].d); }
 
 };
@@ -2877,6 +2988,16 @@ struct DeqIQ4NL {
     static inline int8x16_t load_values() { return vld1q_s8(iq4k_values); }
 };
 
+struct DeqMXFP4 {
+    const int8x16_t mt = load_values();
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_mxfp4& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vqtbl1q_s8(mt, vandq_u8(bits, ml)), vqtbl1q_s8(mt, vshrq_n_u8(bits, 4)) };
+    }
+    static inline int8x16_t load_values() { return vld1q_s8(kvalues_mxfp4); }
+};
+
 struct DeqQ50 {
     inline int8x16x2_t dequant(const block_q5_0& x) const {
@@ -2953,7 +3074,11 @@ void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc
     for (int i = 0; i < nb; ++i) {
         for (int k = 0; k < 8; ++k) {
-            y[i].d[k] = x8[k][i].d;
+            if constexpr (std::is_same_v) {
+                y[i].d[k] = GGML_FP32_TO_FP16(GGML_E8M0_TO_FP32_HALF(x8[k][i].e));
+            } else {
+                y[i].d[k] = x8[k][i].d;
+            }
             vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
             auto qs = (uint32_t *)y[i].qs;
             for (int l = 0; l < 4; ++l) {
@@ -3011,6 +3136,7 @@ bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx
         case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_MXFP4 : iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q8_0 : iqk_convert_qX_q80_r8(n, vx, bx, vy, nrc_x); break;
         default: return false;
     }
@@ -3049,6 +3175,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array
 = 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_Q5_1 : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
         case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+#ifdef HAVE_FANCY_SIMD
         case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+#endif
+        case GGML_TYPE_MXFP4 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ1_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ2_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
@@ -295,6 +298,7 @@ struct MulMat {
         case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+        case GGML_TYPE_MXFP4 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ1_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
@@ -458,6 +462,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
         case GGML_TYPE_Q6_0:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
         //case GGML_TYPE_Q4_0_R8:
         //case GGML_TYPE_Q5_0_R4:
         //case GGML_TYPE_Q6_0_R4:
@@ -871,6 +876,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
         case GGML_TYPE_Q6_0_R4:
         case GGML_TYPE_Q8_0_R8:
         case GGML_TYPE_IQ4_NL_R4:
+        case GGML_TYPE_MXFP4:
             return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ1_M:
@@ -960,6 +966,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_Q8_0_R8:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_IQ4_NL_R4:
+        case GGML_TYPE_MXFP4:
             return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
         case GGML_TYPE_IQ1_BN:
         case GGML_TYPE_IQ2_BN:
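The AVX2/Zen4 hunks above keep two 16-entry lookup tables for MXFP4: load_mxfp4_values_256() with the signed code points from kvalues_mxfp4, and load_unsigned_mxfp4_values_256() with the same values shifted up by 12 so that the shuffled bytes stay non-negative and can feed unsigned-by-signed multiply-adds. The shift is folded back out through the (scale, minus) pair returned by ScaleHelperQ_0_1_MXFP4::prepare1() together with the precomputed quant sum carried by the Q8_1/Q8_2 blocks. A minimal scalar sketch of that identity follows; it assumes kvalues_mxfp4 holds the doubled E2M1 code points {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12} as in mainline ggml, and the names below are illustrative rather than part of this patch:

    // Scalar model of the unsigned-table trick: the signed dot product is
    // recovered from the unsigned one by subtracting 12 * sum(q8), and sum(q8)
    // is exactly the sum the Q8_1/Q8_2 quantization already stores per block.
    #include <cstdint>

    static const int8_t  kvals[16]   = {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12};   // assumed signed table
    static const uint8_t kvals_u[16] = {12, 13, 14, 15, 16, 18, 20, 24, 12, 11, 10, 9, 8, 6, 4, 0};  // table from the hunk above

    int dot_signed_via_unsigned(const uint8_t * idx, const int8_t * q8, int n) {
        int sum_uq = 0, sum_q = 0;
        for (int j = 0; j < n; ++j) {
            sum_uq += kvals_u[idx[j]] * q8[j];   // u8 * i8 product, as _mm256_maddubs_epi16 would form
            sum_q  += q8[j];
        }
        return sum_uq - 12*sum_q;                // == sum(kvals[idx[j]] * q8[j]) since kvals_u[j] = kvals[j] + 12
    }
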
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index ece0b734..184a1aee 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -3697,6 +3697,147 @@ void quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
     iqk_quantize_row_q8_K128(x, vy, k);
 }
 
+// ============================== MXFP4
+
+namespace {
+inline int best_index_mxfp4(float d, const int8_t * values, float x) {
+    float best = std::abs(x - d*values[0]);
+    int index = 0;
+    for (int j = 1; j < 16; ++j) {
+        float diff = std::abs(x - d*values[j]);
+        if (diff < best) { best = diff; index = j; }
+    }
+    return index;
+}
+static void quantize_row_mxfp4_impl(int n_per_row, const float * x, char * cy,
+        [[maybe_unused]] float * weight,
+        const int8_t * values,
+        [[maybe_unused]] const float * quant_weights,
+        [[maybe_unused]] const int ntry) {
+
+    GGML_ASSERT(n_per_row % QK_MXFP4 == 0);
+    GGML_UNUSED(quant_weights);
+
+    block_mxfp4 * y = (block_mxfp4 *)cy;
+
+    //int last_ibl = -1;
+    //float sigma2 = 0;
+
+    //const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+    // -> log2f(amax) ~ e - 125 -> amax = 2^(e - 125)
+    //const float d = GGML_E8M0_TO_FP32_HALF(e);
+
+    for (int ib = 0; ib < n_per_row/QK_MXFP4; ++ib) {
+        memset(&y[ib], 0, sizeof(block_mxfp4));
+        const float * xb = x + ib*QK_MXFP4;
+        //if (int ibl = ib/(QK_K/QK_MXFP4); ibl != last_ibl) {
+        //    int n = std::min(QK_K, n_per_row - ib*QK_MXFP4);
+        //    float sumx2 = 0;
+        //    for (int j = 0; j < n; ++j) sumx2 += xb[j]*xb[j];
+        //    sigma2 = 2.0f*sumx2/n;
+        //    last_ibl = ibl;
+        //}
+        //if (quant_weights) {
+        //    const float * qw = quant_weights + ib*QK_MXFP4;
+        //    for (int j = 0; j < QK_MXFP4; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        //} else {
+        //    for (int j = 0; j < QK_MXFP4; ++j) weight[j] = xb[j]*xb[j];
+        //}
+        float amax = 0;
+        for (int j = 0; j < QK_MXFP4; ++j) {
+            float ax = fabsf(xb[j]);
+            amax = std::max(amax, ax);
+        }
+        if (!amax) {
+            continue;
+        }
+        const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+        const float d = GGML_E8M0_TO_FP32_HALF(e);
+        y[ib].e = e;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            uint8_t v0 = best_index_mxfp4(d, values, xb[j]);
+            uint8_t v1 = best_index_mxfp4(d, values, xb[j+QK_MXFP4/2]);
+            y[ib].qs[j] = v0 | (v1 << 4);
+        }
+    }
+}
+}
+
+void quantize_row_mxfp4_ref(const float * x, block_mxfp4 * y, int64_t k) {
+    quantize_mxfp4(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_mxfp4(const float * x, void * y, int64_t k) {
+    quantize_mxfp4(x, (void *)y, 1, k, nullptr);
+}
+
+size_t quantize_mxfp4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+    constexpr int kBlockSize = QK_MXFP4;
+    GGML_ASSERT(n_per_row%kBlockSize == 0);
+    auto row_size = ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+    char * qrow = (char *)dst;
+    float weight[kBlockSize];
+    for (int64_t row = 0; row < nrows; ++row) {
+        quantize_row_mxfp4_impl(n_per_row, src, qrow, weight, kvalues_mxfp4, imatrix, 7);
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrows * row_size;
+}
+
+void dequantize_row_mxfp4(const block_mxfp4 * x, float * y, int64_t k) {
+    constexpr int kBlockSize = QK_MXFP4;
+    GGML_ASSERT(k%kBlockSize == 0);
+    int nblock = k/kBlockSize;
+    for (int ib = 0; ib < nblock; ++ib) {
+        float d = GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        for (int j = 0; j < kBlockSize/2; ++j) {
+            y[j             ] = d * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            y[j+kBlockSize/2] = d * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        y += kBlockSize;
+    }
+}
+
+void vec_dot_mxfp4_q8_0_x4(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+    if (iqk_mul_mat(1, 1, n, GGML_TYPE_MXFP4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+        return;
+    }
+#endif
+    GGML_ASSERT(n%QK_MXFP4 == 0);
+    GGML_ASSERT(nrc == 1);
+    GGML_UNUSED(bs);
+    GGML_UNUSED(bx);
+    GGML_UNUSED(by);
+    //const block_mxfp4 * x = (const block_mxfp4 *)vx;
+    //const block_q8_K * y = (const block_q8_K *)vy;
+    //int nblock = n/QK_MXFP4;
+    //float sumf = 0;
+    //for (int ibl = 0; ibl < nblock; ++ibl) {
+    //    //int sumi = 0;
+    //    auto qy = y[ibl].qs;
+    //    auto qx = x[ibl].qs;
+    //    float db = d * y[ibl].d;
+    //    for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
+    //        float dl = db * ((x[ibl].scales[ib] & 254) - 127);
+    //        //int ls = (x[ibl].scales[ib] & 254) - 127;
+    //        const int8_t * values = iq4k_values + ((x[ibl].scales[ib] & 1) << 4);
+    //        int suml = 0;
+    //        for (int j = 0; j < kBlockSize/2; ++j) {
+    //            suml += qy[j             ] * values[qx[j] & 0xf]
+    //                  + qy[j + kBlockSize/2] * values[qx[j] >> 4];
+    //        }
+    //        sumf += dl * suml;
+    //        //sumi += ls * suml;
+    //        qy += kBlockSize;
+    //        qx += kBlockSize/2;
+    //    }
+    //    //sumf += d * y[ibl].d * sumi;
+    //}
+    //*s = sumf;
+}
+
 namespace {
 static void quantize_row_iq4_k_impl_bs128(const int super_block_size, const int block_size, int n_per_row, const float * x, char * cy,
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 7d789fba..4ca7987a 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -67,6 +67,12 @@ size_t quantize_iq4_kss(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
 void dequantize_row_iq4_kss(const block_iq4_kss * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 void vec_dot_iq4_kss_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_mxfp4_q8_0_x4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void quantize_row_iq2_ks_ref(const float * GGML_RESTRICT x, block_iq2_ks * GGML_RESTRICT y, int64_t k);
 void quantize_row_iq2_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 size_t quantize_iq2_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
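With the declarations above in place, a quick sanity check of the new type is a scalar round trip through quantize_mxfp4() and dequantize_row_mxfp4(). The sketch below is illustrative only; it assumes ggml.h and iqk_quantize.h are on the include path and that ggml_row_size() already reports the MXFP4 row size (17 bytes per 32 values: one E8M0 exponent byte plus 16 packed nibbles, i.e. the 4.25 bpw quoted further down):

    // Hypothetical round-trip check for the MXFP4 API declared above.
    #include <cmath>
    #include <vector>
    #include "ggml.h"
    #include "iqk_quantize.h"

    void mxfp4_round_trip() {
        const int64_t n_per_row = 256;                    // must be a multiple of QK_MXFP4 (32)
        std::vector<float> src(n_per_row), dst(n_per_row);
        for (int64_t i = 0; i < n_per_row; ++i) src[i] = std::sin(0.1f*(float)i);

        std::vector<char> q(ggml_row_size(GGML_TYPE_MXFP4, n_per_row));
        quantize_mxfp4(src.data(), q.data(), /*nrows=*/1, n_per_row, /*imatrix=*/nullptr);
        dequantize_row_mxfp4((const block_mxfp4 *)q.data(), dst.data(), n_per_row);

        // Each reconstructed value is d * kvalues_mxfp4[idx] with one shared E8M0
        // scale d per 32-value block, so dst should track src to MXFP4 precision.
    }
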
diff --git a/include/llama.h b/include/llama.h
index 1bc1bdaf..0c26868e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -186,6 +186,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q4_0_4_4      = 33, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0_4_8      = 34, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0_8_8      = 35, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_MXFP4         = 38, // except 1d tensors, 38 to be compatible with mainline
         //
         LLAMA_FTYPE_MOSTLY_Q6_0          = 135, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_BN        = 136, // except 1d tensors
diff --git a/src/llama.cpp b/src/llama.cpp
index 47e26a83..50b9ad5c 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4538,6 +4538,7 @@ struct llama_model_loader {
             case GGML_TYPE_Q5_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_0_R4; break;
             case GGML_TYPE_Q6_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_0_R4; break;
             case GGML_TYPE_Q8_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R8; break;
+            case GGML_TYPE_MXFP4:   ftype = LLAMA_FTYPE_MOSTLY_MXFP4;   break;
             case GGML_TYPE_IQ4_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS;  break;
             case GGML_TYPE_IQ4_KS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS;  break;
             case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4; break;
@@ -5294,6 +5295,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_Q5_0_R4:  return "Q5_0_R4 - 5.5 bpw";
         case LLAMA_FTYPE_MOSTLY_Q6_0_R4:  return "Q6_0_R4 - 6.5 bpw";
         case LLAMA_FTYPE_MOSTLY_Q8_0_R8:  return "Q8_0_R8 - 8.5 bpw";
+        case LLAMA_FTYPE_MOSTLY_MXFP4:    return "MXFP4 - 4.25 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ4_XS:   return "IQ4_XS - 4.25 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ4_KS:   return "IQ4_KS - 4.25 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:return "IQ4_KS_R4 - 4.25 bpw";
@@ -20541,6 +20543,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q5_0_R4: default_type = GGML_TYPE_Q5_0_R4; break;
         case LLAMA_FTYPE_MOSTLY_Q6_0_R4: default_type = GGML_TYPE_Q6_0_R4; break;
         case LLAMA_FTYPE_MOSTLY_Q8_0_R8: default_type = GGML_TYPE_Q8_0_R8; break;
+        case LLAMA_FTYPE_MOSTLY_MXFP4:   default_type = GGML_TYPE_MXFP4;  break;
         case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
        case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break;
        case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:default_type = GGML_TYPE_IQ4_KS_R4;break;