mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 10:30:27 +00:00
iq3_kt: speed up quantization
This applies the same trick that the previous commit used for iq2_kt. Here the speedup is even larger: quantization time for LLaMA-3.1-8B on the Ryzen-5975WX drops from 375 seconds to 195 seconds!
This commit is contained in:
@@ -3299,15 +3299,16 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
||||
if (sx[i] > best) { best = sx[i]; jbest = index[i]; }
|
||||
}
|
||||
auto& points = m_in_cluster[jbest];
|
||||
auto& values = m_c_values[jbest];
|
||||
GGML_ASSERT(!points.empty() && points.size()%8 == 0);
|
||||
int jbest_cluster = jbest;
|
||||
vbest = _mm256_set1_ps(0.f);
|
||||
best_index = _mm256_set1_epi32(-1);
|
||||
best = 0; jbest = -1;
|
||||
for (int j = 0; j < int(points.size()); j += 8) {
|
||||
auto idx = _mm256_loadu_si256((const __m256i*)(points.data() + j));
|
||||
auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
auto vq = _mm256_loadu_ps(m_values.data() + kGroupSize*points[j+i]);
|
||||
auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+i));
|
||||
auto sumqx = _mm256_mul_ps(vw, _mm256_mul_ps(vx, vq));
|
||||
auto sumq2 = hsum_float_8(_mm256_mul_ps(vw, _mm256_mul_ps(vq, vq)));
|
||||
sqx[i] = _mm256_mul_ps(_mm256_set1_ps(sumq2 > 0 ? 1/sumq2 : 0), _mm256_mul_ps(sumqx, sumqx));
|
||||
@@ -3327,7 +3328,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
||||
fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
best_idx[l] = jbest;
|
||||
best_idx[l] = points[jbest];
|
||||
}
|
||||
} else {
|
||||
__m128 sqx[4];
|
||||
@@ -3446,15 +3447,18 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
||||
if (sx[i] < best) { best = sx[i]; jbest = index[i]; }
|
||||
}
|
||||
auto& points = m_in_cluster[jbest];
|
||||
auto& values = m_c_values[jbest];
|
||||
GGML_ASSERT(!points.empty() && points.size()%8 == 0);
|
||||
int jbest_cluster = jbest;
|
||||
vbest = _mm256_set1_ps(INFINITY);
|
||||
best_index = _mm256_set1_epi32(-1);
|
||||
best = INFINITY; jbest = -1;
|
||||
auto idx = add_idx;
|
||||
auto add8 = _mm256_set1_epi32(8);
|
||||
for (int j = 0; j < int(points.size()); j += 8) {
|
||||
auto idx = _mm256_loadu_si256((const __m256i*)(points.data() + j));
|
||||
//auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
auto vq = _mm256_loadu_ps(m_values.data() + kGroupSize*points[j+i]);
|
||||
auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+i));
|
||||
auto vdiff = _mm256_sub_ps(vq, vx);
|
||||
sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff));
|
||||
}
|
||||
@@ -3463,6 +3467,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
||||
best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
|
||||
_mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
|
||||
vbest = _mm256_min_ps(vbest, score);
|
||||
idx = _mm256_add_epi32(idx, add8);
|
||||
}
|
||||
_mm256_store_ps(sx, vbest);
|
||||
_mm256_store_si256((__m256i *)index, best_index);
|
||||
@@ -3473,7 +3478,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
||||
fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
best_idx[l] = jbest;
|
||||
best_idx[l] = points[jbest];
|
||||
}
|
||||
} else {
|
||||
__m256 sqx[4];
|
||||
@@ -3482,6 +3487,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
||||
float sx[8];
|
||||
int index[8];
|
||||
auto vid_p = _mm256_set1_ps(id);
|
||||
auto add8 = _mm256_set1_epi32(8);
|
||||
for (int l = 0; l < kNg; ++l) {
|
||||
auto xl = xb + 4*l;
|
||||
auto wl = weight + 4*l;
|
||||
@@ -3492,8 +3498,9 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
||||
auto vbest = _mm256_set1_ps(INFINITY);
|
||||
auto best_index = _mm256_set1_epi32(-1);
|
||||
float best = INFINITY; int jbest = -1;
|
||||
auto idx = add_idx;
|
||||
for (int j = 0; j < ncluster; j += 8) {
|
||||
auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
|
||||
//auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+2*i));
|
||||
auto vdiff = _mm256_sub_ps(vq, vx_p);
|
||||
@@ -3505,6 +3512,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
||||
best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
|
||||
_mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
|
||||
vbest = _mm256_min_ps(vbest, score);
|
||||
idx = _mm256_add_epi32(idx, add8);
|
||||
}
|
||||
_mm256_store_ps(sx, vbest);
|
||||
_mm256_store_si256((__m256i *)index, best_index);
|
||||
@@ -3518,8 +3526,9 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
||||
vbest = _mm256_set1_ps(INFINITY);
|
||||
best_index = _mm256_set1_epi32(-1);
|
||||
best = INFINITY; jbest = -1;
|
||||
idx = add_idx;
|
||||
for (int j = 0; j < int(points.size()); j += 8) {
|
||||
auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
|
||||
//auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+2*i));
|
||||
auto vdiff = _mm256_sub_ps(vq, vx_p);
|
||||
@@ -3531,6 +3540,7 @@ void QuantizerIQKT<block_size, group_size, num_bits, num_clusters>::find_best_ma
|
||||
best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx),
|
||||
_mm256_andnot_si256(_mm256_castps_si256(mask), best_index));
|
||||
vbest = _mm256_min_ps(vbest, score);
|
||||
idx = _mm256_add_epi32(idx, add8);
|
||||
}
|
||||
_mm256_store_ps(sx, vbest);
|
||||
_mm256_store_si256((__m256i *)index, best_index);
|
||||
|
||||
Reference in New Issue
Block a user