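Attempt fix 4: replace GGML_API with an explicit __attribute__ ((visibility ("default"))) on the iqk entry points and drop the "ggml.h" includes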

This commit is contained in:
Saood Karim
2025-04-20 01:09:52 -05:00
parent 5c2380a55b
commit b79a92cb29
3 changed files with 14 additions and 16 deletions

View File

@@ -29,7 +29,7 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
// TODO: get the ggml_type enum here without pollution // TODO: get the ggml_type enum here without pollution
// //
GGML_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, __attribute__ ((visibility ("default"))) bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
int neq3, int neq2, long nbq3, long nbq2, int neq3, int neq2, long nbq3, long nbq2,
int nek3, int nek2, long nbk3, long nbk2, int nek3, int nek2, long nbk3, long nbk2,
int nev3, int nev2, long nbv3, long nbv2, int nev3, int nev2, long nbv3, long nbv2,

View File

@@ -15,7 +15,6 @@
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include "ggml.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include "ggml-quants.h" #include "ggml-quants.h"
#include "iqk_mul_mat.h" #include "iqk_mul_mat.h"
@@ -404,7 +403,7 @@ private:
} }
GGML_API bool iqk_mul_mat(long Nx, long Ny, long ne00, __attribute__ ((visibility ("default"))) bool iqk_mul_mat(long Nx, long Ny, long ne00,
int typeA, const void * A, long strideA, int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB, int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth) { float * C, long stride_C, int ith, int nth) {
@@ -441,7 +440,7 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
} }
} }
GGML_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, __attribute__ ((visibility ("default"))) bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
long ne02, long ne03, long ne12, long ne13, long ne02, long ne03, long ne12, long ne13,
long nb02, long nb03, long nb12, long nb13, long nb2, long nb3, long nb02, long nb03, long nb12, long nb13, long nb2, long nb3,
int typeA, const void * A, long strideA, int typeA, const void * A, long strideA,
@@ -546,7 +545,7 @@ GGML_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
return true; return true;
} }
GGML_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, __attribute__ ((visibility ("default"))) bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
int typeA, const void * A, long strideA, int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB, int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
@@ -572,7 +571,7 @@ GGML_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
return true; return true;
} }
GGML_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, __attribute__ ((visibility ("default"))) bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeA, const void * Aup, const void * Agate, long strideA, int typeA, const void * Aup, const void * Agate, long strideA,
int typeB, const void * B, long strideB, int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
@@ -17551,11 +17550,11 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
#else // IQK_IMPLEMENT #else // IQK_IMPLEMENT
GGML_API bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) { __attribute__ ((visibility ("default"))) bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
return false; return false;
} }
GGML_API bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/, __attribute__ ((visibility ("default"))) bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/,
long /*ne02*/, long /*ne03*/, long /*ne12*/, long /*ne13*/, long /*ne02*/, long /*ne03*/, long /*ne12*/, long /*ne13*/,
long /*nb02*/, long /*nb03*/, long /*nb12*/, long /*nb13*/, long /*nb2*/, long /*nb3*/, long /*nb02*/, long /*nb03*/, long /*nb12*/, long /*nb13*/, long /*nb2*/, long /*nb3*/,
int /*typeA*/, const void * /*A*/, long /*strideA*/, int /*typeA*/, const void * /*A*/, long /*strideA*/,
@@ -17564,12 +17563,12 @@ GGML_API bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/,
return false; return false;
} }
GGML_API bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long, __attribute__ ((visibility ("default"))) bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long,
const void *, int, int) { const void *, int, int) {
return false; return false;
} }
GGML_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/, int /*unary_op*/, __attribute__ ((visibility ("default"))) bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/, int /*unary_op*/,
int /*typeA*/, const void * /*Aup*/, const void * /*Agate*/, long /*strideA*/, int /*typeA*/, const void * /*Aup*/, const void * /*Agate*/, long /*strideA*/,
int /*typeB*/, const void * /*B*/, long /*strideB*/, int /*typeB*/, const void * /*B*/, long /*strideB*/,
float * /*C*/, long /*nb1*/, long /*nb2*/, const void * /*vrow_mapping*/, int /*ith*/, int /*nth*/) { float * /*C*/, long /*nb1*/, long /*nb2*/, const void * /*vrow_mapping*/, int /*ith*/, int /*nth*/) {

View File

@@ -7,36 +7,35 @@
#pragma once #pragma once
#include <stdint.h> #include <stdint.h>
#include <stdbool.h> #include <stdbool.h>
#include "ggml.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
GGML_API bool iqk_mul_mat(long Nx, long Ny, long ne00, __attribute__ ((visibility ("default"))) bool iqk_mul_mat(long Nx, long Ny, long ne00,
int typeA, const void * A, long strideA, int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB, int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth); float * C, long stride_C, int ith, int nth);
GGML_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, __attribute__ ((visibility ("default"))) bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
long ne02, long ne03, long ne12, long ne13, long ne02, long ne03, long ne12, long ne13,
long nb02, long nb03, long nb12, long nb13, long nb2, long nb3, long nb02, long nb03, long nb12, long nb13, long nb2, long nb3,
int typeA, const void * A, long strideA, int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB, int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth); float * C, long stride_C, int ith, int nth);
GGML_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, __attribute__ ((visibility ("default"))) bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
int typeA, const void * A, long strideA, int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB, int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
GGML_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, __attribute__ ((visibility ("default"))) bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeA, const void * Aup, const void * Agate, long strideA, int typeA, const void * Aup, const void * Agate, long strideA,
int typeB, const void * B, long strideB, int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
typedef void (*barrier_t) (void *); typedef void (*barrier_t) (void *);
GGML_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, __attribute__ ((visibility ("default"))) bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
int neq3, int neq2, long nbq3, long nbq2, int neq3, int neq2, long nbq3, long nbq2,
int nek3, int nek2, long nbk3, long nbk2, int nek3, int nek2, long nbk3, long nbk2,
int nev3, int nev2, long nbv3, long nbv2, int nev3, int nev2, long nbv3, long nbv2,