mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 01:50:01 +00:00
WIP
This commit is contained in:
@@ -8523,7 +8523,22 @@ void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
typename Unpacker::Sum4T sum4;
|
||||
ScaleHelperQ8_0 scales;
|
||||
__m256 result[8];
|
||||
float val[8];
|
||||
auto store = [&info, &result] (int ix0) {
|
||||
if constexpr (nrc_y == 1) {
|
||||
info.store(ix0, 0, hsum_float_8x8(result));
|
||||
}
|
||||
else if constexpr (nrc_y == 2) {
|
||||
auto value = hsum_float_8x8(result);
|
||||
auto value1 = _mm256_extractf128_ps(value, 1);
|
||||
info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88));
|
||||
info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd));
|
||||
}
|
||||
else {
|
||||
float val[8];
|
||||
_mm256_storeu_ps(val, hsum_float_8x8(result));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
|
||||
}
|
||||
};
|
||||
if (nb%4 == 0) {
|
||||
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
|
||||
for (int ix = 0; ix < 8/nrc_y; ++ix) {
|
||||
@@ -8531,8 +8546,7 @@ void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
AccumType0<nrc_y, true> accum;
|
||||
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
|
||||
}
|
||||
_mm256_storeu_ps(val, hsum_float_8x8(result));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
|
||||
store(ix0);
|
||||
}
|
||||
} else {
|
||||
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
|
||||
@@ -8541,8 +8555,7 @@ void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
AccumType0<nrc_y, false> accum;
|
||||
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
|
||||
}
|
||||
_mm256_storeu_ps(val, hsum_float_8x8(result));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
|
||||
store(ix0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8589,7 +8602,22 @@ void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
typename Unpacker::Sum4T sum4;
|
||||
ScaleHelperQ8_2 scales;
|
||||
__m256 result[8];
|
||||
float val[8];
|
||||
auto store = [&info, &result] (int ix0) {
|
||||
if constexpr (nrc_y == 1) {
|
||||
info.store(ix0, 0, hsum_float_8x8(result));
|
||||
}
|
||||
else if constexpr (nrc_y == 2) {
|
||||
auto value = hsum_float_8x8(result);
|
||||
auto value1 = _mm256_extractf128_ps(value, 1);
|
||||
info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88));
|
||||
info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd));
|
||||
}
|
||||
else {
|
||||
float val[8];
|
||||
_mm256_storeu_ps(val, hsum_float_8x8(result));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
|
||||
}
|
||||
};
|
||||
if (nb%4 == 0) {
|
||||
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
|
||||
for (int ix = 0; ix < 8/nrc_y; ++ix) {
|
||||
@@ -8597,8 +8625,7 @@ void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
AccumType1<nrc_y, true> accum;
|
||||
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
|
||||
}
|
||||
_mm256_storeu_ps(val, hsum_float_8x8(result));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
|
||||
store(ix0);
|
||||
}
|
||||
} else {
|
||||
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
|
||||
@@ -8607,8 +8634,7 @@ void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
AccumType1<nrc_y, false> accum;
|
||||
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
|
||||
}
|
||||
_mm256_storeu_ps(val, hsum_float_8x8(result));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
|
||||
store(ix0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -16337,7 +16363,6 @@ struct FlashQKV {
|
||||
accumulate_qkv_1(vh, fms);
|
||||
return;
|
||||
}
|
||||
F16::Data v[8];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
if (fms.need_scaling[j] == 2) {
|
||||
@@ -16350,6 +16375,40 @@ struct FlashQKV {
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef __AVX512F__
|
||||
if constexpr ((D/F16::block_size)%4 == 0) {
|
||||
F16::Data v[16];
|
||||
F16::Data vs[4];
|
||||
for (int i = 0; i < D/F16::block_size; i += 4) {
|
||||
for (int l = 0; l < k_step; l += 4) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
vh.load(l+k, i+0, v[4*k+0], v[4*k+1]);
|
||||
vh.load(l+k, i+2, v[4*k+2], v[4*k+3]);
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
auto s1 = F16::load(R + F16::block_size*(i+0));
|
||||
auto s2 = F16::load(R + F16::block_size*(i+1));
|
||||
auto s3 = F16::load(R + F16::block_size*(i+2));
|
||||
auto s4 = F16::load(R + F16::block_size*(i+3));
|
||||
F16::set4(fms.cache + k_step*j + l, vs);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
s1 = F16::fmadd(s1, v[4*k+0], vs[k]);
|
||||
s2 = F16::fmadd(s2, v[4*k+1], vs[k]);
|
||||
s3 = F16::fmadd(s3, v[4*k+2], vs[k]);
|
||||
s4 = F16::fmadd(s4, v[4*k+3], vs[k]);
|
||||
}
|
||||
F16::store(R + F16::block_size*(i+0), s1);
|
||||
F16::store(R + F16::block_size*(i+1), s2);
|
||||
F16::store(R + F16::block_size*(i+2), s3);
|
||||
F16::store(R + F16::block_size*(i+3), s4);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
F16::Data v[8];
|
||||
#ifdef __AVX2__
|
||||
F16::Data vs[4];
|
||||
#endif
|
||||
@@ -17035,7 +17094,7 @@ struct FlashAttn {
|
||||
std::is_same_v<KHelper, HelperQ80<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
|
||||
constexpr size_t kMaxOnStackSize = 576;
|
||||
constexpr size_t kMaxOnStackSize = 18432; //576;
|
||||
auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
|
||||
q_size = GGML_PAD(q_size, 64);
|
||||
if (q_size > kMaxOnStackSize) {
|
||||
|
||||
Reference in New Issue
Block a user