iqk_mul_mat: be independent of llamafile_sgemm (WIP)

* Remove iqk_mul_mat from llamafile_sgemm
* Pass tensor types and strides to iqk_mul_mat

It is marked WIP because only tested on __aarch64__
This commit is contained in:
Iwan Kawrakow
2024-06-11 09:12:22 +02:00
parent 3593891f39
commit ad53eabf87
5 changed files with 77 additions and 101 deletions

View File

@@ -120,7 +120,7 @@ struct MulMat {
funcs[n_left-1](n, vx, bx, info, nrc_x);
}
}
static bool set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny);
static bool set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
private:
template <typename Dequantizer> static void set_functions(MulMat& m);
};
@@ -173,43 +173,50 @@ const uint64_t keven_signs[128] = {
}
bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, const void * B,
bool iqk_mul_mat(int task_type, long Nx, long Ny, long ne00,
int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth) {
MulMat mm;
int row_size_q8;
if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) {
if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) {
return false;
}
auto row_size_qx = ggml_row_size((ggml_type)typeA, ne00);
if (ggml_task_type(task_type) != GGML_TASK_TYPE_COMPUTE) return ggml_task_type(task_type) == GGML_TASK_TYPE_INIT;
auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA));
auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB));
auto nrc_x = (Nx + nth - 1)/nth;
auto first_x = ith*nrc_x;
if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, (size_t)row_size_q8, 0, 1, nullptr, 0};
DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
return true;
}
bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B,
bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
int typeA, const void * A, long strideA,
int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;
assert(row_mapping != nullptr);
MulMat mm;
int row_size_q8;
if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) {
if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) {
return false;
}
int row_size_qx = ggml_row_size((ggml_type)typeA, ne00);
auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA));
auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB));
int nrc_x = (Nx + nth - 1)/nth;
int first_x = ith*nrc_x;
if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)};
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
return true;
}
@@ -236,7 +243,6 @@ inline float hsum_float_8(__m256 x) {
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
constexpr static int nrc_y = nrc;
@@ -2394,14 +2400,14 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
}
}
bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny) {
bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
//if (Ny == 1 && (typeA == GGML_TYPE_IQ3_S || typeA == GGML_TYPE_IQ3_XXS)) {
if (Ny == 999 && typeA == GGML_TYPE_IQ3_S) {
return false;
}
if (typeA == GGML_TYPE_F16) {
if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F32) {
for (auto& f : mm.funcs) f = nullptr;
mm.funcs[0] = mul_mat_fX_fY_T<1, ggml_half, float>;
mm.funcs[1] = mul_mat_fX_fY_T<2, ggml_half, float>;
@@ -2411,10 +2417,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
#ifndef __AVX512F__
mm.funcs[5] = mul_mat_fX_fY_T<6, ggml_half, float>;
#endif
row_size_q8 = ggml_row_size(GGML_TYPE_F32, ne00);
return true;
}
if (typeA == GGML_TYPE_F32) {
if (typeA == GGML_TYPE_F32 && typeB == GGML_TYPE_F16) {
for (auto& f : mm.funcs) f = nullptr;
mm.funcs[0] = mul_mat_fX_fY_T<1, float, ggml_half>;
mm.funcs[1] = mul_mat_fX_fY_T<2, float, ggml_half>;
@@ -2424,7 +2429,6 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
#ifndef __AVX512F__
mm.funcs[5] = mul_mat_fX_fY_T<6, float, ggml_half>;
#endif
row_size_q8 = ggml_row_size(GGML_TYPE_F16, ne00);
return true;
}
// Using the standard legacy quant template is slightly faster than tiling
@@ -2441,7 +2445,7 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
// return true;
// }
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
auto expected_typeB = GGML_TYPE_Q8_K;
switch (typeA) {
case GGML_TYPE_Q2_K:
@@ -2491,33 +2495,35 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
MulMat::set_functions<Q4_0_Unpacker>(mm);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
expected_typeB = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q4_1:
assert (ne00 % QK4_1 == 0);
MulMat::set_functions<Q4_1_Unpacker>(mm);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
expected_typeB = GGML_TYPE_Q8_1;
break;
case GGML_TYPE_Q5_0:
assert (ne00 % QK5_0 == 0);
MulMat::set_functions<Q5_0_Unpacker>(mm);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
expected_typeB = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q5_1:
assert (ne00 % QK5_1 == 0);
MulMat::set_functions<Q5_1_Unpacker>(mm);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
expected_typeB = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q8_0:
assert (ne00 % QK8_0 == 0);
MulMat::set_functions<Q8_0_Unpacker>(mm);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
expected_typeB = GGML_TYPE_Q8_0;
break;
default:
return false;
}
if (typeB != expected_typeB) return false;
return true;
}
@@ -3882,7 +3888,7 @@ IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0
template <int nrc_y>
void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QF16Base::k_step == 0);
GGML_ASSERT(n%QF16Base::k_step == 0);
constexpr int k_nx = 5;
const char * cx = (const char *)vx;
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
@@ -3933,10 +3939,10 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
}
}
bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /*Ny*/) {
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
if (typeA == GGML_TYPE_F16) {
if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) {
if (ne00%8) return false;
for (auto& f : m.funcs) f = nullptr;
m.funcs[0] = mul_mat_f16_f16_T<1>;
m.funcs[1] = mul_mat_f16_f16_T<2>;
@@ -3945,10 +3951,11 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
m.funcs[4] = mul_mat_f16_f16_T<5>;
//m.funcs[5] = mul_mat_f16_f16_T<6>;
//m.funcs[6] = mul_mat_f16_f16_T<7>;
row_size_q8 = ggml_row_size(GGML_TYPE_F16, ne00);
return true;
}
auto expected_Btype = GGML_TYPE_Q8_K;
switch (typeA) {
case GGML_TYPE_Q2_K:
MulMat::set_functions<DequantizerQ2K>(m);
@@ -3985,28 +3992,29 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q4_1:
MulMat::set_functions<DequantizerQ41>(m);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
expected_Btype = GGML_TYPE_Q8_1;
break;
case GGML_TYPE_Q5_0:
MulMat::set_functions<DequantizerQ50>(m);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q5_1:
MulMat::set_functions<DequantizerQ51>(m);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
expected_Btype = GGML_TYPE_Q8_1;
break;
case GGML_TYPE_Q8_0:
MulMat::set_functions<DequantizerQ80>(m);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
expected_Btype = GGML_TYPE_Q8_0;
break;
default:
return false;
}
return true;
return typeB == expected_Btype;
}
}