mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
WIP
This commit is contained in:
@@ -427,13 +427,13 @@ static float find_best_scale(int block_size, const float * xb, const float * wei
|
||||
}
|
||||
|
||||
static void analyze_x_v2(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_mse_q, float& tot_elements) {
|
||||
constexpr int kNumVal = 1 << 12;
|
||||
constexpr int kNumVal = 1 << 15;
|
||||
constexpr int kBlockSize = 32;
|
||||
constexpr int kGroupSize = 8;
|
||||
constexpr int kNg = kBlockSize/kGroupSize;
|
||||
constexpr int kSuperBlockSize = 256;
|
||||
static_assert(kNumVal%8 == 0);
|
||||
auto codes = make_values(kNumVal, kGroupSize, 31.5f);
|
||||
auto codes = make_values(kNumVal, kGroupSize, 31.75f);
|
||||
int nthread = std::max(1, int(std::thread::hardware_concurrency()/2));
|
||||
int chunk = (nrows + 8*nthread - 1)/(8*nthread);
|
||||
std::mutex mutex;
|
||||
@@ -461,14 +461,29 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
|
||||
#endif
|
||||
for (int row = first; row < last; ++row) {
|
||||
auto xr = values + row*n_per_row;
|
||||
float sigma2 = 0;
|
||||
for (int j = 0; j < n_per_row; ++j) sigma2 += xr[j]*xr[j];
|
||||
sigma2 /= n_per_row;
|
||||
//int last_ibl = -1;
|
||||
for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) {
|
||||
//int ibl = ib*kBlockSize/kSuperBlockSize;
|
||||
//if (ibl != last_ibl) {
|
||||
// auto xbl = xr + ibl*kSuperBlockSize;
|
||||
// int n = kSuperBlockSize*(ibl + 1) <= n_per_row ? kSuperBlockSize : n_per_row - ibl*kSuperBlockSize;
|
||||
// float sumx2 = 0;
|
||||
// for (int i = 0; i < n; ++i) sumx2 += xbl[i]*xbl[i];
|
||||
// sigma2 = sumx2/n;
|
||||
// last_ibl = ibl;
|
||||
//}
|
||||
auto xb = xr + kBlockSize*ib;
|
||||
for (int i = 0; i < kBlockSize; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i];
|
||||
float d = find_best_scale(kBlockSize, xb, weight.data(), iq4k_values, 5);
|
||||
float id = d ? 1/d : 0.f;
|
||||
#ifdef __AVX2__
|
||||
auto vid = _mm256_set1_ps(id);
|
||||
for (int l = 0; l < kNg; ++l) {
|
||||
auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xb+8*l));
|
||||
auto vw = _mm256_loadu_ps(weight.data() + 8*l);
|
||||
auto vbest = _mm256_set1_ps(INFINITY);
|
||||
auto best_index = _mm256_set1_epi32(-1);
|
||||
float best = INFINITY; int jbest = -1;
|
||||
@@ -477,7 +492,7 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*(j+i));
|
||||
auto vdiff = _mm256_sub_ps(vq, vx);
|
||||
sqx[i] = _mm256_mul_ps(vdiff, vdiff);
|
||||
sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
|
||||
}
|
||||
auto score = hsum_float_8x8(sqx);
|
||||
auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ);
|
||||
@@ -496,13 +511,15 @@ static void analyze_x_v2(const char * name, int nrows, int n_per_row, const floa
|
||||
auto vq2 = _mm256_setzero_ps();
|
||||
for (int l = 0; l < kNg; ++l) {
|
||||
auto vx = _mm256_loadu_ps(xb+8*l);
|
||||
auto vw = _mm256_loadu_ps(weight.data() + 8*l);
|
||||
auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*best_idx[ib*kNg + l]);
|
||||
vqx = _mm256_fmadd_ps(vq, vx, vqx);
|
||||
vq2 = _mm256_fmadd_ps(vq, vq, vq2);
|
||||
auto vqw = _mm256_mul_ps(vq, vw);
|
||||
vqx = _mm256_fmadd_ps(vqw, vx, vqx);
|
||||
vq2 = _mm256_fmadd_ps(vqw, vq, vq2);
|
||||
}
|
||||
auto sumqx = hsum_float_8(vqx);
|
||||
auto sumq2 = hsum_float_8(vq2);
|
||||
scales[ib] = sumqx/sumq2;
|
||||
scales[ib] = sumq2 > 0 ? sumqx/sumq2 : 0.f;
|
||||
#else
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user