mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
WIP
This commit is contained in:
@@ -277,11 +277,20 @@ static inline float hsum_float_4(__m128 x) {
|
||||
static inline float hsum_float_8(__m256 x) {
|
||||
return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
|
||||
}
|
||||
// Horizontally sums eight 8-float vectors in one pass.
//
// accm must point to at least eight __m256 values (accm[0] .. accm[7] are read).
// Returns a vector whose lane i holds the sum of the eight floats of the
// original accm[i].
//
// NOTE: accm[0] .. accm[3] are used as scratch space and are overwritten;
// accm[4] .. accm[7] are left untouched.
static __m256 hsum_float_8x8(__m256 * accm) {
    // Stage 1: fold the upper 128-bit half of every input onto its lower half.
    // The folded input k goes to the low half of accm[k], the folded input k+4
    // to the high half, so four registers now carry all eight inputs.
    for (int k = 0; k < 4; ++k) {
        const __m128 folded_k  = _mm_add_ps(_mm256_castps256_ps128(accm[k]),
                                            _mm256_extractf128_ps(accm[k], 1));
        const __m128 folded_k4 = _mm_add_ps(_mm256_castps256_ps128(accm[k+4]),
                                            _mm256_extractf128_ps(accm[k+4], 1));
        accm[k] = _mm256_set_m128(folded_k4, folded_k);
    }

    // Stage 2: interleave registers k and k+2 and add, which pairs up lanes
    // 0+2 and 1+3 of each 128-bit half. Two partial sums per input remain.
    for (int k = 0; k < 2; ++k) {
        const __m256 mixed_lo = _mm256_unpacklo_ps(accm[k], accm[k+2]);
        const __m256 mixed_hi = _mm256_unpackhi_ps(accm[k], accm[k+2]);
        accm[k] = _mm256_add_ps(mixed_lo, mixed_hi);
    }

    // Stage 3: one more interleave-and-add merges the last two partial sums
    // and leaves the totals ordered by input index.
    const __m256 final_lo = _mm256_unpacklo_ps(accm[0], accm[1]);
    const __m256 final_hi = _mm256_unpackhi_ps(accm[0], accm[1]);
    return _mm256_add_ps(final_lo, final_hi);
}
|
||||
#endif
|
||||
|
||||
static void analyze_x(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_elements) {
|
||||
constexpr int kNumVal = 1 << 12;
|
||||
constexpr int kBlockSize = 8;
|
||||
static_assert(kNumVal%8 == 0);
|
||||
auto codes = make_values(kNumVal, kBlockSize);
|
||||
std::vector<float> sumq2(kNumVal);
|
||||
for (int j = 0; j < kNumVal; ++j) {
|
||||
@@ -307,6 +316,8 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
|
||||
int last = std::min(first + chunk, nrows);
|
||||
#ifdef __AVX2__
|
||||
__m256 vx[kBlockSize/8];
|
||||
__m256 sqx[8];
|
||||
float sx[8];
|
||||
#endif
|
||||
for (int row = first; row < last; ++row) {
|
||||
auto xr = values + row*n_per_row;
|
||||
@@ -315,16 +326,19 @@ static void analyze_x(const char * name, int nrows, int n_per_row, const float *
|
||||
auto xb = xr + kBlockSize*ib;
|
||||
#ifdef __AVX2__
|
||||
for (int l = 0; l < kBlockSize/8; ++l) vx[l] = _mm256_loadu_ps(xb+8*l);
|
||||
for (int j = 0; j < kNumVal; ++j) {
|
||||
if (!sumq2[j]) continue;
|
||||
auto sx = _mm256_setzero_ps();
|
||||
for (int l = 0; l < kBlockSize/8; ++l) {
|
||||
auto qv = _mm256_loadu_ps(codes.data() + kBlockSize*j + 8*l);
|
||||
sx = _mm256_fmadd_ps(vx[l], qv, sx);
|
||||
for (int j = 0; j < kNumVal; j += 8) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
sqx[i] = _mm256_setzero_ps();
|
||||
for (int l = 0; l < kBlockSize/8; ++l) {
|
||||
auto qv = _mm256_loadu_ps(codes.data() + kBlockSize*(j+i) + 8*l);
|
||||
sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]);
|
||||
}
|
||||
}
|
||||
float sumqx = hsum_float_8(sx);
|
||||
if (sumqx*sumqx > best*sumq2[j]) {
|
||||
d = sumqx/sumq2[j]; best = d*sumqx; jbest = j;
|
||||
_mm256_storeu_ps(sx, hsum_float_8x8(sqx));
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
if (sumq2[j+i] > 0 && sx[i]*sx[i] > best*sumq2[j+i]) {
|
||||
d = sx[i]/sumq2[j+i]; best = d*sx[i]; jbest = j+i;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
||||
Reference in New Issue
Block a user