This commit is contained in:
Iwan Kawrakow
2025-04-21 19:03:26 +03:00
parent a7cd27f7e0
commit 39714026fe

View File

@@ -8523,7 +8523,22 @@ void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
typename Unpacker::Sum4T sum4;
ScaleHelperQ8_0 scales;
__m256 result[8];
float val[8];
// Reduces the 8 accumulators in `result` with hsum_float_8x8 and hands the sums to
// info.store for the tile whose first x-row is ix0.
// Lane nrc_y*ix + iy of the reduced vector is the value for x-row ix0+ix, y-row iy
// (this is the indexing used by the generic branch at the bottom).
auto store = [&info, &result] (int ix0) {
    if constexpr (nrc_y == 1) {
        // Single y-row: the whole 8-lane vector is passed to info.store in one call.
        info.store(ix0, 0, hsum_float_8x8(result));
    } else if constexpr (nrc_y == 2) {
        // Lanes alternate between y-row 0 and y-row 1, so de-interleave in registers
        // instead of spilling to a scalar buffer.
        const auto sums  = hsum_float_8x8(result);
        const auto lower = _mm256_castps256_ps128(sums);   // lanes 0..3
        const auto upper = _mm256_extractf128_ps(sums, 1); // lanes 4..7
        // 0x88 picks elements {0, 2} of each operand -> even lanes (y-row 0).
        info.store(ix0, 0, _mm_shuffle_ps(lower, upper, 0x88));
        // 0xdd picks elements {1, 3} of each operand -> odd lanes (y-row 1).
        info.store(ix0, 1, _mm_shuffle_ps(lower, upper, 0xdd));
    } else {
        // Generic case: dump the sums to memory and scatter them one scalar at a time.
        float sums[8];
        _mm256_storeu_ps(sums, hsum_float_8x8(result));
        const int rows_per_tile = 8/nrc_y;
        for (int row_y = 0; row_y < nrc_y; ++row_y) {
            for (int row_x = 0; row_x < rows_per_tile; ++row_x) {
                info.store(ix0 + row_x, row_y, sums[nrc_y*row_x + row_y]);
            }
        }
    }
};
if (nb%4 == 0) {
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
for (int ix = 0; ix < 8/nrc_y; ++ix) {
@@ -8531,8 +8546,7 @@ void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
AccumType0<nrc_y, true> accum;
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
}
_mm256_storeu_ps(val, hsum_float_8x8(result));
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
store(ix0);
}
} else {
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
@@ -8541,8 +8555,7 @@ void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
AccumType0<nrc_y, false> accum;
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
}
_mm256_storeu_ps(val, hsum_float_8x8(result));
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
store(ix0);
}
}
}
@@ -8589,7 +8602,22 @@ void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
typename Unpacker::Sum4T sum4;
ScaleHelperQ8_2 scales;
__m256 result[8];
float val[8];
// Reduces the 8 accumulators in `result` with hsum_float_8x8 and hands the sums to
// info.store for the tile whose first x-row is ix0.
// Lane nrc_y*ix + iy of the reduced vector is the value for x-row ix0+ix, y-row iy
// (this is the indexing used by the generic branch at the bottom).
auto store = [&info, &result] (int ix0) {
    if constexpr (nrc_y == 1) {
        // Single y-row: the whole 8-lane vector is passed to info.store in one call.
        info.store(ix0, 0, hsum_float_8x8(result));
    } else if constexpr (nrc_y == 2) {
        // Lanes alternate between y-row 0 and y-row 1, so de-interleave in registers
        // instead of spilling to a scalar buffer.
        const auto sums  = hsum_float_8x8(result);
        const auto lower = _mm256_castps256_ps128(sums);   // lanes 0..3
        const auto upper = _mm256_extractf128_ps(sums, 1); // lanes 4..7
        // 0x88 picks elements {0, 2} of each operand -> even lanes (y-row 0).
        info.store(ix0, 0, _mm_shuffle_ps(lower, upper, 0x88));
        // 0xdd picks elements {1, 3} of each operand -> odd lanes (y-row 1).
        info.store(ix0, 1, _mm_shuffle_ps(lower, upper, 0xdd));
    } else {
        // Generic case: dump the sums to memory and scatter them one scalar at a time.
        float sums[8];
        _mm256_storeu_ps(sums, hsum_float_8x8(result));
        const int rows_per_tile = 8/nrc_y;
        for (int row_y = 0; row_y < nrc_y; ++row_y) {
            for (int row_x = 0; row_x < rows_per_tile; ++row_x) {
                info.store(ix0 + row_x, row_y, sums[nrc_y*row_x + row_y]);
            }
        }
    }
};
if (nb%4 == 0) {
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
for (int ix = 0; ix < 8/nrc_y; ++ix) {
@@ -8597,8 +8625,7 @@ void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
AccumType1<nrc_y, true> accum;
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
}
_mm256_storeu_ps(val, hsum_float_8x8(result));
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
store(ix0);
}
} else {
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
@@ -8607,8 +8634,7 @@ void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
AccumType1<nrc_y, false> accum;
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
}
_mm256_storeu_ps(val, hsum_float_8x8(result));
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
store(ix0);
}
}
}
@@ -16337,7 +16363,6 @@ struct FlashQKV {
accumulate_qkv_1(vh, fms);
return;
}
F16::Data v[8];
for (int j = 0; j < q_step; ++j) {
auto R = qkv_cache + D*j;
if (fms.need_scaling[j] == 2) {
@@ -16350,6 +16375,40 @@ struct FlashQKV {
}
}
}
#ifdef __AVX512F__
// AVX512-only fast path, taken when the number of F16 blocks per row
// (D/F16::block_size) is a multiple of 4.
// Each pass handles 4 blocks of a row (index i) and 4 rows of vh (index l), so every
// accumulator is read from and written back to qkv_cache once per 4 rows of vh.
// NOTE(review): the l-loop steps by 4 with no tail handling, so this presumably
// relies on k_step being a multiple of 4 -- confirm.
if constexpr ((D/F16::block_size)%4 == 0) {
// v[4*k + m] receives the data for row l+k, block i+m (m = 0..3).
F16::Data v[16];
// vs[k] is the per-row factor applied to row l+k for the current j.
F16::Data vs[4];
for (int i = 0; i < D/F16::block_size; i += 4) {
for (int l = 0; l < k_step; l += 4) {
// Load the 4x4 tile of vh once; it is reused for every j below.
for (int k = 0; k < 4; ++k) {
// presumably vh.load(row, block, a, b) fills a and b with blocks `block` and
// `block+1` of that row -- confirm against the helper.
vh.load(l+k, i+0, v[4*k+0], v[4*k+1]);
vh.load(l+k, i+2, v[4*k+2], v[4*k+3]);
}
for (int j = 0; j < q_step; ++j) {
// R points at the D-wide accumulator row for index j.
auto R = qkv_cache + D*j;
auto s1 = F16::load(R + F16::block_size*(i+0));
auto s2 = F16::load(R + F16::block_size*(i+1));
auto s3 = F16::load(R + F16::block_size*(i+2));
auto s4 = F16::load(R + F16::block_size*(i+3));
// presumably fills vs[0..3] from the 4 values at fms.cache[k_step*j + l ...] -- confirm.
F16::set4(fms.cache + k_step*j + l, vs);
for (int k = 0; k < 4; ++k) {
// presumably s += v * vs[k] (fused multiply-add) -- confirm F16::fmadd argument order.
s1 = F16::fmadd(s1, v[4*k+0], vs[k]);
s2 = F16::fmadd(s2, v[4*k+1], vs[k]);
s3 = F16::fmadd(s3, v[4*k+2], vs[k]);
s4 = F16::fmadd(s4, v[4*k+3], vs[k]);
}
F16::store(R + F16::block_size*(i+0), s1);
F16::store(R + F16::block_size*(i+1), s2);
F16::store(R + F16::block_size*(i+2), s3);
F16::store(R + F16::block_size*(i+3), s4);
}
}
}
// This path has done all the work; skip the code that follows the #endif.
return;
}
#endif
F16::Data v[8];
#ifdef __AVX2__
F16::Data vs[4];
#endif
@@ -17035,7 +17094,7 @@ struct FlashAttn {
std::is_same_v<KHelper, HelperQ80<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
constexpr size_t kMaxOnStackSize = 576;
constexpr size_t kMaxOnStackSize = 18432; //576;
auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
q_size = GGML_PAD(q_size, 64);
if (q_size > kMaxOnStackSize) {