mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
WIP
This commit is contained in:
@@ -3164,7 +3164,8 @@ public:
|
|||||||
const float * values() const { return m_values.data(); }
|
const float * values() const { return m_values.data(); }
|
||||||
|
|
||||||
inline void find_best_match(float d, const float * xb, const float * weight, int * best_idx) const;
|
inline void find_best_match(float d, const float * xb, const float * weight, int * best_idx) const;
|
||||||
inline float find_best_scale(const float * xb, const float * weight, const int * best_idx) const;
|
inline void find_best_match(const float * xb, const float * weight, int * best_idx) const;
|
||||||
|
inline std::pair<float, float> find_best_scale(const float * xb, const float * weight, const int * best_idx) const;
|
||||||
|
|
||||||
static inline void set_values(uint32_t i, float * result, float scale) {
|
static inline void set_values(uint32_t i, float * result, float scale) {
|
||||||
constexpr uint32_t ka = 89226354;
|
constexpr uint32_t ka = 89226354;
|
||||||
@@ -3223,18 +3224,18 @@ QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::QuantizerIQKT() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <int block_size, int group_size, int num_bits, int num_clusters>
|
template <int block_size, int group_size, int num_bits, int num_clusters>
|
||||||
float QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_scale(const float * xb, const float * weight, const int * best_idx) const {
|
std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_scale(
|
||||||
|
const float * xb, const float * weight, const int * best_idx) const {
|
||||||
float sumqx = 0, sumq2 = 0;
|
float sumqx = 0, sumq2 = 0;
|
||||||
#ifdef z__AVX2__
|
#ifdef __AVX2__
|
||||||
// TODO: fix this for kGroupSize != 8
|
|
||||||
auto vqx = _mm256_setzero_ps();
|
auto vqx = _mm256_setzero_ps();
|
||||||
auto vq2 = _mm256_setzero_ps();
|
auto vq2 = _mm256_setzero_ps();
|
||||||
for (int l = 0; l < kNg; ++l) {
|
for (int l = 0; l < kBlockSize; l += 8) {
|
||||||
auto vx = _mm256_loadu_ps(xb+8*l);
|
auto vx = _mm256_loadu_ps(xb+l);
|
||||||
auto vw = _mm256_loadu_ps(weight+8*l);
|
auto vw = _mm256_loadu_ps(weight+l);
|
||||||
auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l]) :
|
auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) :
|
||||||
_mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l+1]),
|
_mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]),
|
||||||
_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l+0]));
|
_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0]));
|
||||||
auto vqw = _mm256_mul_ps(vq, vw);
|
auto vqw = _mm256_mul_ps(vq, vw);
|
||||||
vqx = _mm256_fmadd_ps(vqw, vx, vqx);
|
vqx = _mm256_fmadd_ps(vqw, vx, vqx);
|
||||||
vq2 = _mm256_fmadd_ps(vqw, vq, vq2);
|
vq2 = _mm256_fmadd_ps(vqw, vq, vq2);
|
||||||
@@ -3252,7 +3253,149 @@ float QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
return sumq2 > 0 ? sumqx/sumq2 : 0.f;
|
return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int block_size, int group_size, int num_bits, int num_clusters>
|
||||||
|
void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_match(const float * xb, const float * weight, int * best_idx) const {
|
||||||
|
int ncluster = m_clusters.size()/kGroupSize;
|
||||||
|
#ifdef __AVX2__
|
||||||
|
if constexpr (kGroupSize == 8) {
|
||||||
|
__m256 sqx[8];
|
||||||
|
const __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
|
||||||
|
float sx[8];
|
||||||
|
int index[8];
|
||||||
|
for (int l = 0; l < kNg; ++l) {
|
||||||
|
auto xl = xb + 8*l;
|
||||||
|
auto wl = weight + 8*l;
|
||||||
|
auto vx = _mm256_loadu_ps(xl);
|
||||||
|
auto vw = _mm256_loadu_ps(wl);
|
||||||
|
auto vbest = _mm256_set1_ps(0.f);
|
||||||
|
auto best_index = _mm256_set1_epi32(-1);
|
||||||
|
float best = 0; int jbest = -1;
|
||||||
|
for (int j = 0; j < ncluster; j += 8) {
|
||||||
|
auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+i));
|
||||||
|
auto sumqx = _mm256_mul_ps(vw, _mm256_mul_ps(vx, vq));
|
||||||
|
auto sumq2 = hsum_float_8(_mm256_mul_ps(vw, _mm256_mul_ps(vq, vq)));
|
||||||
|
sqx[i] = _mm256_mul_ps(_mm256_set1_ps(sumq2 > 0 ? 1/sumq2 : 0), _mm256_mul_ps(sumqx, sumqx));
|
||||||
|
}
|
||||||
|
auto score = hsum_float_8x8(sqx);
|
||||||
|
auto mask = _mm256_cmp_ps(score, vbest, _CMP_GT_OQ);
|
||||||
|
best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
|
||||||
|
_mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
|
||||||
|
vbest = _mm256_max_ps(vbest, score);
|
||||||
|
}
|
||||||
|
_mm256_store_ps(sx, vbest);
|
||||||
|
_mm256_store_si256((__m256i *)index, best_index);
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
if (sx[i] > best) { best = sx[i]; jbest = index[i]; }
|
||||||
|
}
|
||||||
|
auto& points = m_in_cluster[jbest];
|
||||||
|
GGML_ASSERT(!points.empty() && points.size()%8 == 0);
|
||||||
|
int jbest_cluster = jbest;
|
||||||
|
vbest = _mm256_set1_ps(0.f);
|
||||||
|
best_index = _mm256_set1_epi32(-1);
|
||||||
|
best = 0; jbest = -1;
|
||||||
|
for (int j = 0; j < int(points.size()); j += 8) {
|
||||||
|
auto idx = _mm256_loadu_si256((const __m256i*)(points.data() + j));
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
auto vq = _mm256_loadu_ps(m_values.data() + kGroupSize*points[j+i]);
|
||||||
|
auto sumqx = _mm256_mul_ps(vw, _mm256_mul_ps(vx, vq));
|
||||||
|
auto sumq2 = hsum_float_8(_mm256_mul_ps(vw, _mm256_mul_ps(vq, vq)));
|
||||||
|
sqx[i] = _mm256_mul_ps(_mm256_set1_ps(sumq2 > 0 ? 1/sumq2 : 0), _mm256_mul_ps(sumqx, sumqx));
|
||||||
|
}
|
||||||
|
auto score = hsum_float_8x8(sqx);
|
||||||
|
auto mask = _mm256_cmp_ps(score, vbest, _CMP_GT_OQ);
|
||||||
|
best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
|
||||||
|
_mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
|
||||||
|
vbest = _mm256_max_ps(vbest, score);
|
||||||
|
}
|
||||||
|
_mm256_store_ps(sx, vbest);
|
||||||
|
_mm256_store_si256((__m256i *)index, best_index);
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
if (sx[i] > best) { best = sx[i]; jbest = index[i]; }
|
||||||
|
}
|
||||||
|
if (jbest < 0) {
|
||||||
|
fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
best_idx[l] = jbest;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
__m128 sqx[4];
|
||||||
|
const __m128i add_idx = _mm_set_epi32(3, 2, 1, 0);
|
||||||
|
float sx[4];
|
||||||
|
int index[4];
|
||||||
|
for (int l = 0; l < kNg; ++l) {
|
||||||
|
auto xl = xb + 4*l;
|
||||||
|
auto wl = weight + 4*l;
|
||||||
|
auto vx = _mm_loadu_ps(xl);
|
||||||
|
auto sumx2 = hsum_float_4(_mm_mul_ps(vx, vx));
|
||||||
|
if (!sumx2) {
|
||||||
|
best_idx[l] = 0; continue;
|
||||||
|
}
|
||||||
|
auto vw = _mm_loadu_ps(wl);
|
||||||
|
auto vbest = _mm_set1_ps(0);
|
||||||
|
auto best_index = _mm_set1_epi32(-1);
|
||||||
|
float best = 0; int jbest = -1;
|
||||||
|
for (int j = 0; j < ncluster; j += 4) {
|
||||||
|
auto idx = _mm_add_epi32(_mm_set1_epi32(j), add_idx);
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
auto vq = _mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i));
|
||||||
|
auto sumqx = _mm_mul_ps(vw, _mm_mul_ps(vx, vq));
|
||||||
|
auto sumq2 = hsum_float_4(_mm_mul_ps(vw, _mm_mul_ps(vq, vq)));
|
||||||
|
sqx[i] = _mm_mul_ps(_mm_set1_ps(sumq2 > 0 ? 1/sumq2 : 0), _mm_mul_ps(sumqx, sumqx));
|
||||||
|
}
|
||||||
|
auto score = hsum_float_4x4(sqx);
|
||||||
|
auto mask = _mm_cmp_ps(score, vbest, _CMP_GT_OQ);
|
||||||
|
best_index = _mm_or_si128(_mm_and_si128(_mm_castps_si128(mask), idx),
|
||||||
|
_mm_andnot_si128(_mm_castps_si128(mask), best_index));
|
||||||
|
vbest = _mm_max_ps(vbest, score);
|
||||||
|
}
|
||||||
|
_mm_store_ps(sx, vbest);
|
||||||
|
_mm_store_si128((__m128i *)index, best_index);
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
if (sx[i] > best) { best = sx[i]; jbest = index[i]; }
|
||||||
|
}
|
||||||
|
GGML_ASSERT(jbest >= 0 && jbest <= int(m_in_cluster.size()));
|
||||||
|
auto& points = m_in_cluster[jbest];
|
||||||
|
GGML_ASSERT(!points.empty() && points.size()%4 == 0);
|
||||||
|
int jbest_cluster = jbest;
|
||||||
|
vbest = _mm_set1_ps(0);
|
||||||
|
best_index = _mm_set1_epi32(-1);
|
||||||
|
best = 0; jbest = -1;
|
||||||
|
for (int j = 0; j < int(points.size()); j += 4) {
|
||||||
|
auto idx = _mm_loadu_si128((const __m128i*)(points.data() + j));
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
auto vq = _mm_loadu_ps(m_values.data() + kGroupSize*points[j+i]);
|
||||||
|
auto sumqx = _mm_mul_ps(vw, _mm_mul_ps(vx, vq));
|
||||||
|
auto sumq2 = hsum_float_4(_mm_mul_ps(vw, _mm_mul_ps(vq, vq)));
|
||||||
|
sqx[i] = _mm_mul_ps(_mm_set1_ps(sumq2 > 0 ? 1/sumq2 : 0), _mm_mul_ps(sumqx, sumqx));
|
||||||
|
}
|
||||||
|
auto score = hsum_float_4x4(sqx);
|
||||||
|
auto mask = _mm_cmp_ps(score, vbest, _CMP_GT_OQ);
|
||||||
|
best_index = _mm_or_si128(_mm_and_si128(_mm_castps_si128(mask), idx),
|
||||||
|
_mm_andnot_si128(_mm_castps_si128(mask), best_index));
|
||||||
|
vbest = _mm_max_ps(vbest, score);
|
||||||
|
}
|
||||||
|
_mm_store_ps(sx, vbest);
|
||||||
|
_mm_store_si128((__m128i *)index, best_index);
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
if (sx[i] > best) { best = sx[i]; jbest = index[i]; }
|
||||||
|
}
|
||||||
|
if (jbest < 0) {
|
||||||
|
fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
best_idx[l] = jbest;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// TODO
|
||||||
|
std::memset(best_idx, 0, kNg*sizeof(int));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int block_size, int group_size, int num_bits, int num_clusters>
|
template <int block_size, int group_size, int num_bits, int num_clusters>
|
||||||
@@ -3329,6 +3472,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
|||||||
} else {
|
} else {
|
||||||
__m128 sqx[4];
|
__m128 sqx[4];
|
||||||
const __m128i add_idx = _mm_set_epi32(3, 2, 1, 0);
|
const __m128i add_idx = _mm_set_epi32(3, 2, 1, 0);
|
||||||
|
const __m128 sign_bit = _mm_set1_ps(-0.f);
|
||||||
float sx[4];
|
float sx[4];
|
||||||
int index[4];
|
int index[4];
|
||||||
auto vid = _mm_set1_ps(id);
|
auto vid = _mm_set1_ps(id);
|
||||||
@@ -3345,7 +3489,9 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
|||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
auto vq = _mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i));
|
auto vq = _mm_loadu_ps(m_clusters.data() + kGroupSize*(j+i));
|
||||||
auto vdiff = _mm_sub_ps(vq, vx);
|
auto vdiff = _mm_sub_ps(vq, vx);
|
||||||
sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
|
//sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
|
||||||
|
vdiff = _mm_andnot_ps(sign_bit, vdiff);
|
||||||
|
sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, _mm_mul_ps(vdiff, vdiff)));
|
||||||
}
|
}
|
||||||
auto score = hsum_float_4x4(sqx);
|
auto score = hsum_float_4x4(sqx);
|
||||||
auto mask = _mm_cmp_ps(score, vbest, _CMP_LT_OQ);
|
auto mask = _mm_cmp_ps(score, vbest, _CMP_LT_OQ);
|
||||||
@@ -3369,7 +3515,9 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
|||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
auto vq = _mm_loadu_ps(m_values.data() + kGroupSize*points[j+i]);
|
auto vq = _mm_loadu_ps(m_values.data() + kGroupSize*points[j+i]);
|
||||||
auto vdiff = _mm_sub_ps(vq, vx);
|
auto vdiff = _mm_sub_ps(vq, vx);
|
||||||
sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
|
//sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, vdiff));
|
||||||
|
vdiff = _mm_andnot_ps(sign_bit, vdiff);
|
||||||
|
sqx[i] = _mm_mul_ps(vw, _mm_mul_ps(vdiff, _mm_mul_ps(vdiff, vdiff)));
|
||||||
}
|
}
|
||||||
auto score = hsum_float_4x4(sqx);
|
auto score = hsum_float_4x4(sqx);
|
||||||
auto mask = _mm_cmp_ps(score, vbest, _CMP_LT_OQ);
|
auto mask = _mm_cmp_ps(score, vbest, _CMP_LT_OQ);
|
||||||
@@ -3589,7 +3737,8 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f
|
|||||||
}
|
}
|
||||||
float d = amax/96.f;
|
float d = amax/96.f;
|
||||||
quantizer.find_best_match(d, xb, weight, best_idx);
|
quantizer.find_best_match(d, xb, weight, best_idx);
|
||||||
scales[ib] = quantizer.find_best_scale(xb, weight, best_idx);
|
auto pair = quantizer.find_best_scale(xb, weight, best_idx);
|
||||||
|
scales[ib] = pair.first;
|
||||||
|
|
||||||
for (int j = 0; j < Q::kNg; ++j) qs[j] = best_idx[j];
|
for (int j = 0; j < Q::kNg; ++j) qs[j] = best_idx[j];
|
||||||
qs += Q::kNg;
|
qs += Q::kNg;
|
||||||
@@ -3665,7 +3814,8 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f
|
|||||||
const float * xb = xbl + Q::kBlockSize*ib;
|
const float * xb = xbl + Q::kBlockSize*ib;
|
||||||
const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
|
const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
|
||||||
for (int j = 0; j < Q::kNg; ++j) best_idx[j] = qs[ib*Q::kNg+j];
|
for (int j = 0; j < Q::kNg; ++j) best_idx[j] = qs[ib*Q::kNg+j];
|
||||||
scales[ib] = quantizer.find_best_scale(xb, weight, best_idx);
|
auto pair = quantizer.find_best_scale(xb, weight, best_idx);
|
||||||
|
scales[ib] = pair.first;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
float id = d ? 1/d : 0.f;
|
float id = d ? 1/d : 0.f;
|
||||||
@@ -3809,9 +3959,21 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
|
|||||||
float ax = std::abs(xb[j]);
|
float ax = std::abs(xb[j]);
|
||||||
amax = std::max(amax, ax);
|
amax = std::max(amax, ax);
|
||||||
}
|
}
|
||||||
float d = amax/96.f;
|
scales[ib] = 0;
|
||||||
quantizer.find_best_match(d, xb, weight, best_idx);
|
if (!amax) continue;
|
||||||
scales[ib] = quantizer.find_best_scale(xb, weight, best_idx);
|
float best = 0;
|
||||||
|
for (int itry = -5; itry <= 5; ++itry) {
|
||||||
|
quantizer.find_best_match(amax/(96.f + 4.f*itry), xb, weight, best_idx);
|
||||||
|
auto [d, score] = quantizer.find_best_scale(xb, weight, best_idx);
|
||||||
|
if (score > best) {
|
||||||
|
best = score;
|
||||||
|
scales[ib] = d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//float d = amax/96.f;
|
||||||
|
//quantizer.find_best_match(d, xb, weight, best_idx);
|
||||||
|
////quantizer.find_best_match(xb, weight, best_idx);
|
||||||
|
//scales[ib] = quantizer.find_best_scale(xb, weight, best_idx);
|
||||||
|
|
||||||
for (int j = 0; j < Q::kNg; ++j) {
|
for (int j = 0; j < Q::kNg; ++j) {
|
||||||
int jj = ib*Q::kNg + j;
|
int jj = ib*Q::kNg + j;
|
||||||
|
|||||||
Reference in New Issue
Block a user