* Adding gp option to llama-bench

Similar to pg, but it only looks at TG speed with a given
prompt length.

* Make q8_0_r4 work with tensor row sizes that are not a multiple of 128

They still need to be divisible by 32.

* Make q8_0_r4 work with tensor row sizes that are not a multiple of 128

.. on NEON

* Make q8_0_r4 work with tensor row sizes that are not a multiple of 128

.., on AVX2

* Make q4_0_r4 work with tensor row sizes that are not a multiple of 128

.., on AVX2

* Make q4_0_r4 work with tensor row sizes that are not a multiple of 128

... on NEON

* Make q4_0_r4 work with tensor row sizes that are not a multiple of 128

... on Zen4.

Also fix q8_0 K-cache for head sizes that are not multiple of 128.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-01-29 14:05:41 +02:00
committed by GitHub
parent f725576345
commit 4a73c25002
2 changed files with 434 additions and 166 deletions

View File

@@ -220,6 +220,7 @@ struct cmd_params {
std::vector<int> n_prompt;
std::vector<int> n_gen;
std::vector<std::pair<int, int>> n_pg;
std::vector<std::pair<int, int>> n_gp;
std::vector<int> n_batch;
std::vector<int> n_ubatch;
std::vector<ggml_type> type_k;
@@ -248,6 +249,7 @@ static const cmd_params cmd_params_defaults = {
/* n_prompt */ {512},
/* n_gen */ {128},
/* n_pg */ {},
/* n_gp */ {},
/* n_batch */ {2048},
/* n_ubatch */ {512},
/* type_k */ {GGML_TYPE_F16},
@@ -280,6 +282,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
printf(" -pg <pp,tg> (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str());
printf(" -gp <pp,tg> (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_gp, pair_str), ",").c_str());
printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
printf(" -ub, --ubatch-size <n> (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str());
printf(" -ctk, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
@@ -393,6 +396,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.n_pg.push_back({std::stoi(p[0]), std::stoi(p[1])});
} else if (arg == "-gp") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = string_split<std::string>(argv[i], ',');
if (p.size() != 2) {
invalid_param = true;
break;
}
params.n_gp.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
} else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) {
invalid_param = true;
@@ -596,6 +610,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; }
if (params.n_pg.empty()) { params.n_pg = cmd_params_defaults.n_pg; }
if (params.n_gp.empty()) { params.n_gp = cmd_params_defaults.n_gp; }
if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; }
if (params.n_ubatch.empty()) { params.n_ubatch = cmd_params_defaults.n_ubatch; }
if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
@@ -614,7 +629,19 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
return params;
}
enum test_kind_type {
// measure mean prompt processing rate without token generation
TEST_KIND_PP,
// measure mean token generation rate without prompt processing
TEST_KIND_TG,
// measure mean prompt processing and token generation rate
TEST_KIND_PG,
// measure mean token generation rate after processing prompt of given length
TEST_KIND_GP,
};
struct cmd_params_instance {
test_kind_type test_kind;
std::string model;
int n_prompt;
int n_gen;
@@ -701,6 +728,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_PP,
/* .model = */ m,
/* .n_prompt = */ n_prompt,
/* .n_gen = */ 0,
@@ -728,6 +756,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_PP,
/* .model = */ m,
/* .n_prompt = */ 0,
/* .n_gen = */ n_gen,
@@ -755,6 +784,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_PP,
/* .model = */ m,
/* .n_prompt = */ n_pg.first,
/* .n_gen = */ n_pg.second,
@@ -776,6 +806,34 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
};
instances.push_back(instance);
}
for (const auto & n_gp : params.n_gp) {
if (n_gp.first == 0 && n_gp.second == 0) {
continue;
}
cmd_params_instance instance = {
/* .test_kind = */ TEST_KIND_GP,
/* .model = */ m,
/* .n_prompt = */ n_gp.first,
/* .n_gen = */ n_gp.second,
/* .n_batch = */ nb,
/* .n_ubatch = */ nub,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
/* .n_gpu_layers = */ nl,
/* .rpc_servers = */ rpc,
/* .split_mode = */ sm,
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
/* .repack = */ params.repack,
};
instances.push_back(instance);
}
}
return instances;
@@ -816,6 +874,8 @@ struct test {
int n_gen;
std::string test_time;
std::vector<uint64_t> samples_ns;
test_kind_type test_kind;
std::string test_label;
test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) {
model_filename = inst.model;
@@ -841,11 +901,32 @@ struct test {
repack = inst.repack;
n_prompt = inst.n_prompt;
n_gen = inst.n_gen;
test_kind = inst.test_kind;
// RFC 3339 date-time format
time_t t = time(NULL);
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
test_time = buf;
// prepare test label for printing
switch (test_kind) {
case TEST_KIND_PP:
snprintf(buf, sizeof(buf), "pp%d", n_prompt);
break;
case TEST_KIND_TG:
snprintf(buf, sizeof(buf), "tg%d", n_gen);
break;
case TEST_KIND_PG:
snprintf(buf, sizeof(buf), "pp%d+tg%d", n_prompt, n_gen);
break;
case TEST_KIND_GP:
snprintf(buf, sizeof(buf), "tg%d@pp%d", n_gen, n_prompt);
break;
default:
snprintf(buf, sizeof(buf), "unknown");
break;
}
test_label = buf;
(void) ctx;
}
@@ -858,7 +939,7 @@ struct test {
}
std::vector<double> get_ts() const {
int n_tokens = n_prompt + n_gen;
int n_tokens = (test_kind == TEST_KIND_GP ? 0 : n_prompt) + n_gen;
std::vector<double> ts;
std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts), [n_tokens](uint64_t t) { return 1e9 * n_tokens / t; });
return ts;
@@ -911,7 +992,7 @@ struct test {
"tensor_split", "use_mmap", "embeddings", "repack",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
"avg_ts", "stddev_ts"
"avg_ts", "stddev_ts", "test",
};
return fields;
}
@@ -967,7 +1048,8 @@ struct test {
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
std::to_string(avg_ts()), std::to_string(stdev_ts())
std::to_string(avg_ts()), std::to_string(stdev_ts()),
test_label
};
return values;
}
@@ -1269,14 +1351,15 @@ struct markdown_printer : public printer {
value += "+RPC";
}
} else if (field == "test") {
if (t.n_prompt > 0 && t.n_gen == 0) {
snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
} else if (t.n_gen > 0 && t.n_prompt == 0) {
snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
} else {
snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
}
value = buf;
//if (t.n_prompt > 0 && t.n_gen == 0) {
// snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
//} else if (t.n_gen > 0 && t.n_prompt == 0) {
// snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
//} else {
// snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
//}
//value = buf;
value = t.test_label;
} else if (field == "t/s") {
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts());
value = buf;
@@ -1489,6 +1572,7 @@ int main(int argc, char ** argv) {
if (t.n_prompt > 0) {
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
}
if (t.test_kind == TEST_KIND_GP) t_start = get_time_ns();
if (t.n_gen > 0) {
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
}

View File

@@ -111,6 +111,15 @@ struct Perf {
#define IQK_ALWAYS_INLINE __attribute__((__always_inline__))
#endif
#if defined __x86_64__
#if defined HAVE_FANCY_SIMD
#undef HAVE_FANCY_SIMD
#endif
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
#define HAVE_FANCY_SIMD
#endif
#endif
namespace {
typedef struct {
@@ -236,6 +245,35 @@ struct MulMat {
}
static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
static inline int num_rows(ggml_type type) {
#ifdef HAVE_FANCY_SIMD
switch (type) {
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K_R4:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K_R4:
case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS_R4:
case GGML_TYPE_IQ2_S_R4:
case GGML_TYPE_IQ3_XXS_R4:
case GGML_TYPE_IQ3_S_R4: return 4;
case GGML_TYPE_IQ4_NL_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_BF16_R16: return 16;
default: return 1;
}
#else
switch (type) {
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K_R4:
@@ -263,6 +301,7 @@ struct MulMat {
case GGML_TYPE_BF16_R16: return 16;
default: return 1;
}
#endif
}
private:
template <typename Dequantizer> static void set_functions(MulMat& m);
@@ -377,13 +416,6 @@ const uint64_t keven_signs[128] = {
#if defined __x86_64__
#if defined HAVE_FANCY_SIMD
#undef HAVE_FANCY_SIMD
#endif
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
#define HAVE_FANCY_SIMD
#endif
namespace {
inline float hsum_float_4(__m128 x) {
@@ -2608,6 +2640,15 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D
acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(helper.val[k+4]), acc2);
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto qy = (const block_q8_1 *)q8.y[0];
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
auto sumi = accum_q4_0_quants(v, qy[ib].qs);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1);
acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc2);
}
acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1);
info.store(ix, 0, acc1);
}
@@ -2645,6 +2686,18 @@ static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f));
prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = accum_q4_0_quants(v, qy[ib].qs);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
@@ -2664,9 +2717,38 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m4 = _mm512_set1_epi8(0xf);
int nb = n / QK4_NL;
GGML_ASSERT(nb%4 == 0);
__m512 acc[2*nrc_y] = {};
__m512i qx[8];
auto prepare = [&qx, &m4] (const block_iq4_nl_r8& iq4l, const block_iq4_nl_r8& iq4h) {
auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l.d));
auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h.d));
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
for (int j = 0; j < 4; ++j) {
auto bits = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+j)),
_mm256_loadu_si256((const __m256i *)iq4h.qs+j), 1);
qx[j+0] = _mm512_and_si512(bits, m4);
qx[j+4] = _mm512_and_si512(_mm512_srli_epi16(bits, 4), m4);
}
return scales;
};
auto dot = [&qx] (const int8_t * qy) {
auto y4l = _mm_loadu_si128((const __m128i*)qy+0);
auto y4h = _mm_loadu_si128((const __m128i*)qy+1);
auto y8l = MM256_SET_M128I(y4l, y4l);
auto y8h = MM256_SET_M128I(y4h, y4h);
auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
return sumi;
};
float d8[8*nrc_y];
for (int ix = 0; ix < nrc_x; ix += 16) {
const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx);
@@ -2676,47 +2758,25 @@ static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
_mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[4*ib4+k].d));
auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[4*ib4+k].d));
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
_mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
_mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
auto bits3 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+2)),
_mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+2), 1);
auto bits4 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+3)),
_mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+3), 1);
qx[0] = _mm512_and_si512(bits1, m4);
qx[1] = _mm512_and_si512(bits2, m4);
qx[2] = _mm512_and_si512(bits3, m4);
qx[3] = _mm512_and_si512(bits4, m4);
qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4);
qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4);
qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits3, 4), m4);
qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits4, 4), m4);
auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
auto y8l = MM256_SET_M128I(y4l, y4l);
auto y8h = MM256_SET_M128I(y4h, y4h);
auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
auto sumi = dot(q8.y[iy][ib4].qs+32*k);
auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq4l[ib], iq4h[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(qy[ib].qs);
auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
@@ -2981,12 +3041,56 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
#endif
#ifdef HAVE_FANCY_SIMD
inline __m512i qx_r8_q8_dot_product(const __m512i * qx, const int8_t * y) {
auto y4l = _mm_loadu_si128((const __m128i*)y+0);
auto y4h = _mm_loadu_si128((const __m128i*)y+1);
auto y8l = MM256_SET_M128I(y4l, y4l);
auto y8h = MM256_SET_M128I(y4h, y4h);
auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
return sumi;
}
inline __m256i qx_r8_q8_dot_product(const __m256i * qx, const int8_t * y) {
auto y4l = _mm_loadu_si128((const __m128i*)y+0);
auto y4h = _mm_loadu_si128((const __m128i*)y+1);
auto yl = MM256_SET_M128I(y4l, y4l);
auto yh = MM256_SET_M128I(y4h, y4h);
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
return sumi;
}
inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i * qx) {
qx[0] = _mm256_loadu_si256((const __m256i *)x+0);
qx[1] = _mm256_loadu_si256((const __m256i *)x+1);
qx[2] = _mm256_loadu_si256((const __m256i *)x+2);
qx[3] = _mm256_loadu_si256((const __m256i *)x+3);
qx[4] = _mm256_loadu_si256((const __m256i *)x+4);
qx[5] = _mm256_loadu_si256((const __m256i *)x+5);
qx[6] = _mm256_loadu_si256((const __m256i *)x+6);
qx[7] = _mm256_loadu_si256((const __m256i *)x+7);
return qx_r8_q8_dot_product(qx, y);
}
template <int nrc_y>
static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%16 == 0);
Q8<nrc_y, block_q8_1_x4> q8(info);
int nb = n / QK8_0;
GGML_ASSERT(nb%4 == 0);
if constexpr (nrc_y == 1) {
__m256 acc[2] = {};
__m256i qx[8];
@@ -2997,32 +3101,22 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
_mm256_storeu_ps(d8, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)));
for (int k = 0; k < 4; ++k) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
qx[0] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0);
qx[1] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1);
qx[2] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2);
qx[3] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3);
qx[4] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4);
qx[5] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5);
qx[6] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6);
qx[7] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7);
auto y4l = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+0);
auto y4h = _mm_loadu_si128((const __m128i*)q8.y[0][ib4].qs+2*k+1);
auto yl = MM256_SET_M128I(y4l, y4l);
auto yh = MM256_SET_M128I(y4h, y4h);
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k]));
acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[k+4]), acc[1]);
}
}
if (4*(nb/4) < nb) {
auto qy = (const block_q8_1 *)q8.y[0];
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[1]);
}
}
info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0]));
acc[0] = acc[1] = _mm256_setzero_ps();
}
@@ -3046,27 +3140,29 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
_mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y4l = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
auto y4h = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
auto y8l = MM256_SET_M128I(y4l, y4l);
auto y8h = MM256_SET_M128I(y4h, y4h);
auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
auto sumi = qx_r8_q8_dot_product(qx, q8.y[iy][ib4].qs+32*k);
auto dy = _mm512_set1_ps(d8[8*iy+k]);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d));
auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d));
auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
for (int j = 0; j < 8; ++j) {
qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)),
_mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-127.f), acc[2*iy+1], acc[2*iy+0]);
info.store(ix, iy, sum512);
@@ -3082,9 +3178,22 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m1 = _mm256_set1_epi16(1);
int nb = n / QK8_0;
GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
float d8[4*nrc_y];
__m256i qx[4], sx[4];
auto dot = [&qx, &sx, &m1] (const int8_t * qy) {
auto y128 = _mm_loadu_si128((const __m128i*)qy);
auto y = MM256_SET_M128I(y128, y128);
auto sumi1 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])))
);
auto sumi2 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])))
);
return _mm256_add_epi32(sumi1, sumi2);
};
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
@@ -3094,54 +3203,49 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
}
for (int k = 0; k < 4; ++k) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
auto q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0);
auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1);
auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2);
auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3);
auto s0 = _mm256_sign_epi8(q0, q0);
auto s1 = _mm256_sign_epi8(q1, q1);
auto s2 = _mm256_sign_epi8(q2, q2);
auto s3 = _mm256_sign_epi8(q3, q3);
for (int j = 0; j < 4; ++j) {
qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+j);
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
auto y = MM256_SET_M128I(y128, y128);
auto sumi1 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
);
auto sumi2 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
);
auto sumi = _mm256_add_epi32(sumi1, sumi2);
auto sumi = dot(q8.y[iy][ib4].qs+32*k);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4);
q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5);
q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6);
q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7);
s0 = _mm256_sign_epi8(q0, q0);
s1 = _mm256_sign_epi8(q1, q1);
s2 = _mm256_sign_epi8(q2, q2);
s3 = _mm256_sign_epi8(q3, q3);
for (int j = 0; j < 4; ++j) {
qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j);
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
auto y = MM256_SET_M128I(y128, y128);
auto sumi1 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
);
auto sumi2 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
);
auto sumi = _mm256_add_epi32(sumi1, sumi2);
auto sumi = dot(q8.y[iy][ib4].qs+32*k+16);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
for (int j = 0; j < 4; ++j) {
qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j);
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(qy[ib].qs);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
for (int j = 0; j < 4; ++j) {
qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j);
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(qy[ib].qs+16);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
@@ -7080,6 +7184,7 @@ struct QFBase {
static inline Acc acc_first(const Data& y, const Data& x) {
return _mm512_mul_ps(y, x);
}
static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); }
static inline float hsum(Acc acc) {
return _mm512_reduce_add_ps(acc);
}
@@ -7118,6 +7223,7 @@ struct QFBase {
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
return _mm256_fmadd_ps(y, x, prev);
}
static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); }
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
@@ -7190,6 +7296,44 @@ template <typename Float, int nrc_in> struct QFT final : public QFBase {
const Float * y[nrc];
};
// TBD if we want this
//template <typename Qy, typename Qx>
//IQK_NOINLINE void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
// static_assert(Qy::nrc == 1);
// int nb = n/QFBase::k_step;
// int nb4 = n/4;
// Qy y(info);
// Qx x(cx + ix0*bx, bx);
// QFBase::Data xv[2*Qx::nrc];
// QFBase::Acc acc[2*Qx::nrc];
// auto yv1 = y.load1(0, 0);
// auto yv2 = y.load1(0, 1);
// for (int ix = 0; ix < Qx::nrc; ++ix) {
// xv[2*ix+0] = x.load1(ix, 0);
// xv[2*ix+1] = x.load1(ix, 1);
// acc[2*ix+0] = QFBase::acc_first(yv1, xv[2*ix+0]);
// acc[2*ix+1] = QFBase::acc_first(yv2, xv[2*ix+1]);
// }
// for (int i = 1; i < nb/2; ++i) {
// yv1 = y.load1(0, 2*i+0);
// yv2 = y.load1(0, 2*i+1);
// for (int ix = 0; ix < Qx::nrc; ++ix) {
// xv[2*ix+0] = x.load1(ix, 2*i+0);
// xv[2*ix+1] = x.load1(ix, 2*i+1);
// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[2*ix+0]);
// acc[2*ix+1] = QFBase::acc(acc[2*ix+1], yv2, xv[2*ix+1]);
// }
// }
// for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
// yv1 = y.load_tail(0, i);
// for (int ix = 0; ix < Qx::nrc; ++ix) {
// xv[ix] = x.load_tail(ix, i);
// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[ix]);
// }
// }
// for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(QFBase::add(acc[2*ix+0], acc[2*ix+1])));
//}
template <typename Qy, typename Qx>
IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
int nb = n/QFBase::k_step;
@@ -7287,12 +7431,29 @@ inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, co
// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
template <int nrc_y, typename FloatX, typename FloatY>
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const char * cx = (const char *)vx;
// TBD if we want this
//if constexpr (nrc_y == 1) {
// constexpr int k_nx = 2;
// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
// mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
// }
// if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) {
// int nx = nrc_x - lastx;
// switch (nx) {
// case 1: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info); break;
// case 2: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, lastx, info); break;
// case 3: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, lastx, info); break;
// }
// //mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info);
// }
// return;
//}
#ifdef __AVX512F__
constexpr int k_nx = 5;
#else
constexpr int k_nx = nrc_y == 1 ? 4 : 2;
#endif
const char * cx = (const char *)vx;
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
}
@@ -12146,7 +12307,6 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
Q8<nrc_y, block_q8_0_x4> q8(info);
Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[16];
float d8[4*nrc_y];
float32x4_t acc[2*nrc_y] = {};
@@ -12168,6 +12328,18 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = deq.prepare(ib, 0, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_0 *)q8.y[iy];
auto y = vld1q_s8_x2(qy[ib].qs);
auto sumi1 = interleaved_dotq(qx+0, y);
auto sumi2 = interleaved_dotq(qx+8, y);
auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, deq.result(acc[2*iy+0]));
info.store(ix+4, iy, deq.result(acc[2*iy+1]));
@@ -12312,12 +12484,32 @@ struct Q6_0_R4_Dequantizer {
const int8x16_t m32 = vdupq_n_s8(-32);
};
inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
auto y = vld1q_s8_x2(qy);
sumi1 = sumi2 = vdupq_n_s32(0);
sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
}
template <int nrc_y>
void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
int nb = n / QK8_0;
GGML_ASSERT(nb%4 == 0);
float32x4_t acc[2*nrc_y] = {};
int8x16_t qx[16];
float d8[4*nrc_y];
@@ -12332,32 +12524,29 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
int32x4_t sumi1, sumi2;
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi1 = vdupq_n_s32(0);
auto sumi2 = vdupq_n_s32(0);
sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
auto dy = vdupq_n_f32(d8[4*iy+k]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
int32x4_t sumi1, sumi2;
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_0 *)q8.y[iy];
qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, acc[2*iy+0]);
info.store(ix+4, iy, acc[2*iy+1]);
@@ -13033,10 +13222,10 @@ struct HelperQ80R4 : public BaseHelper<step> {
m2 = _mm256_unpacklo_epi64(t2, t3);
m3 = _mm256_unpackhi_epi64(t2, t3);
#ifdef HAVE_FANCY_SIMD
m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128));
m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128));
m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128));
m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128));
m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
#endif
_mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0);
_mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1);
@@ -13055,10 +13244,10 @@ struct HelperQ80R4 : public BaseHelper<step> {
m2 = _mm256_unpacklo_epi64(t2, t3);
m3 = _mm256_unpackhi_epi64(t2, t3);
#ifdef HAVE_FANCY_SIMD
m0 = _mm256_xor_si256(m0, _mm256_set1_epi8(-128));
m1 = _mm256_xor_si256(m1, _mm256_set1_epi8(-128));
m2 = _mm256_xor_si256(m2, _mm256_set1_epi8(-128));
m3 = _mm256_xor_si256(m3, _mm256_set1_epi8(-128));
m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
#endif
_mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0);
_mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1);
@@ -13895,16 +14084,11 @@ struct FlashQKfp32 {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
if constexpr (D >= 128) {
#ifdef HAVE_FANCY_SIMD
MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q8_0_1_Unpacker, nq);
MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q8_0_1_Unpacker, nq);
#else
MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
#endif
} else {
// This does not actually work until we fix K-cache to be quantized to Q8_0_x4 only if D%128 == 0
MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
}
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {