iqk_mul_mat: be able to handle any f16/f32 combination on AVX2

But only turning on f16 x f32 and f32 x f16 for now.
This commit is contained in:
Kawrakow
2024-06-10 16:43:42 +03:00
parent 1211a4b5d0
commit 154f56a8de
2 changed files with 88 additions and 57 deletions

View File

@@ -866,22 +866,41 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
if (Ctype != GGML_TYPE_F32)
return false;
if (task == GGML_TASK_TYPE_COMPUTE && k >= 256 && Atype == GGML_TYPE_F16) {
#if defined __AVX2__ && defined __FMA__
if (Btype == GGML_TYPE_F32) {
if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
return true;
}
}
//bool is_accepted_float_type = k >= 32 && Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F32;
bool is_accepted_float_type = k >= 32 &&
((Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F32) || (Atype == GGML_TYPE_F32 && Btype == GGML_TYPE_F16));
#elif defined __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && defined __ARM_FEATURE_FMA
if (Btype == GGML_TYPE_F16) {
if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
return true;
}
}
bool is_accepted_float_type = k >= 32 && Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F16;
#else
bool is_accepted_float_type = false;
#endif
if (task == GGML_TASK_TYPE_INIT && is_accepted_float_type) {
return true;
}
if (task == GGML_TASK_TYPE_COMPUTE && is_accepted_float_type) {
if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
return true;
}
}
// if (task == GGML_TASK_TYPE_COMPUTE && k >= 32 && Atype == GGML_TYPE_F16) {
//#if defined __AVX2__ && defined __FMA__
// if (Btype == GGML_TYPE_F32) {
// if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
// return true;
// }
// }
//#elif defined __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && defined __ARM_FEATURE_FMA
// if (Btype == GGML_TYPE_F16) {
// if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
// return true;
// }
// }
//#endif
// }
switch (Atype) {
case GGML_TYPE_F32: {