SVD POC: experimenting
Do the SVD before or after quantization, quantize the SVD result, etc. So far, none of the versions is competitive with just using more bits per weight (bpw) in the quantization.
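The idea being probed: peel a few rank-1 factors off each 2D weight tensor (or off its quantization residual) and carry them separately, hoping they capture structure the block-wise quantizer misses. A minimal sketch of that rank-1 peeling step follows; it assumes a row-major matrix, the names (extract_rank1, niter) are invented for the illustration, and it omits the AVX2/AVX-512 paths, the threading, and the IQ2_K re-quantization of the factors that the actual try_svd() in this commit performs.

// Illustrative sketch only (hypothetical names, not code from this commit):
// alternating least squares that repeatedly extracts the leading rank-1 factor
// g*f^T from a row-major matrix q (nrows x n_per_row) and subtracts it,
// leaving the residual in q.
#include <algorithm>
#include <vector>

static void extract_rank1(std::vector<float> & q, int nrows, int n_per_row, int nsvd, int niter) {
    std::vector<float> f(n_per_row); // per-column factor
    std::vector<float> g(nrows);     // per-row factor
    for (int isvd = 0; isvd < nsvd; ++isvd) {
        std::fill(g.begin(), g.end(), 1.0f);
        for (int iter = 0; iter < niter; ++iter) {
            // f[j] = sum_row g[row]*q[row][j] / sum_row g[row]^2  (least squares in f for fixed g)
            float gg = 0; for (float x : g) gg += x*x;
            if (gg == 0) break;
            std::fill(f.begin(), f.end(), 0.0f);
            for (int row = 0; row < nrows; ++row) {
                const float * qr = q.data() + row*n_per_row;
                for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j];
            }
            for (int j = 0; j < n_per_row; ++j) f[j] /= gg;
            // g[row] = sum_j q[row][j]*f[j] / sum_j f[j]^2  (least squares in g for fixed f)
            float ff = 0; for (float x : f) ff += x*x;
            if (ff == 0) return; // nothing left to extract
            for (int row = 0; row < nrows; ++row) {
                const float * qr = q.data() + row*n_per_row;
                float sum = 0;
                for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
                g[row] = sum/ff;
            }
        }
        // subtract the converged rank-1 term; q now holds the new residual
        for (int row = 0; row < nrows; ++row) {
            float * qr = q.data() + row*n_per_row;
            for (int j = 0; j < n_per_row; ++j) qr[j] -= g[row]*f[j];
        }
    }
}

In the commit itself, the SVD_BEFORE define selects whether try_svd() is applied to the tensor before the quantization round-trip or to it afterwards, and the extracted f/g factors are themselves passed through quantize_iq2_k / dequantize_row_iq2_k.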
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "ggml.h"
 #include "llama.h"
+#include "iqk/iqk_quantize.h"

 #include <algorithm>
 #include <cassert>
@@ -153,14 +154,15 @@ static bool tensor_is_contiguous(const struct ggml_tensor * tensor) {

 static void test_roundtrip_on_chunk(
     const ggml_tensor * layer, int64_t offset, int64_t chunk_size, const ggml_type_traits_t & qfns, bool use_reference,
-    float * input_scratch, char * quantized_scratch, float * output_scratch, error_stats & stats
-) {
-    if (layer->type == GGML_TYPE_F16) {
-        for (int i = 0; i < chunk_size; i++) {
-            input_scratch[i] = ggml_get_f32_1d(layer, i + offset);
-        }
-    } else {
-        input_scratch = ggml_get_data_f32(layer) + offset;
-    }
+    float * input_scratch, char * quantized_scratch, float * output_scratch, error_stats & stats, bool fill_input) {
+    if (fill_input) {
+        if (layer->type == GGML_TYPE_F16) {
+            for (int i = 0; i < chunk_size; i++) {
+                input_scratch[i] = ggml_get_f32_1d(layer, i + offset);
+            }
+        } else {
+            input_scratch = ggml_get_data_f32(layer) + offset;
+        }
+    }

     if (use_reference) {
@@ -303,7 +305,7 @@ static void do_svd_iteration(int n_per_row, int nrows, const float * q, float *

 }

-static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter, int verbosity = 1) {
+static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter, char * scratch, int verbosity = 1) {
     constexpr int kNiter = 10;
     if (nsvd_iter < 1) nsvd_iter = kNiter;
     if (nsvd > nrows) nsvd = nrows;
@@ -315,14 +317,17 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
     GGML_ASSERT(n_per_row%64 == 0);
     use_avx2 = true;
 #endif
+    float max_error = 0;
     for (int j = 0; j < nelem; ++j) {
         q[j] = b[j] - q[j];
         mse += q[j]*q[j];
+        max_error = std::max(max_error, std::abs(q[j]));
     }
     int nthread = std::max(1, int(std::thread::hardware_concurrency()/2));
     std::vector<std::thread> workers(nthread);
     std::vector<float> work;
-    if (verbosity > 0) printf("===================== %s(%d x %d, %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem));
+    if (verbosity > 0) printf("===================== %s(%d x %d, %d, %d): rmse = %g, max_err = %g\n", __func__,
+            n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem), max_error);
     float mse_old = mse;
     std::vector<float> f(n_per_row), g(nrows, 1);
     for (int isvd = 0; isvd < nsvd; ++isvd) {
@@ -335,105 +340,29 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
             if (mse_old/this_mse - 1 < 1e-6f) break;
             mse_old = this_mse;
         }
-        // for (int iter = 0; iter < nsvd_iter; ++iter) {
-        //     std::memset(f.data(), 0, f.size()*sizeof(float));
-        //     float sum_f = 0;
-        //#ifdef __AVX2__
-        //     int part = 0;
-        //#ifdef __AVX512F__
-        //     {
-        //         auto vnorm = _mm512_set1_ps(norm);
-        //         __m512 sums[16] = {};
-        //         for (; part < n_per_row/256; ++part) {
-        //             for (int row = 0; row < nrows; ++row) {
-        //                 __m512 vg = _mm512_set1_ps(g[row]);
-        //                 auto qr = q + row*n_per_row + 256*part;
-        //                 for (int k = 0; k < 16; ++k) {
-        //                     auto vq = _mm512_loadu_ps(qr + 16*k);
-        //                     sums[k] = _mm512_fmadd_ps(vg, vq, sums[k]);
-        //                 }
-        //             }
-        //             __m512 tot = _mm512_setzero_ps();
-        //             for (int k = 0; k < 16; ++k) {
-        //                 sums[k] = _mm512_mul_ps(vnorm, sums[k]);
-        //                 _mm512_storeu_ps(f.data() + 256*part + 16*k, sums[k]);
-        //                 tot = _mm512_fmadd_ps(sums[k], sums[k], tot);
-        //                 sums[k] = _mm512_setzero_ps();
-        //             }
-        //             sum_f += _mm512_reduce_add_ps(tot);
-        //         }
-        //         part = 4*(n_per_row/256);
-        //     }
-        //#endif
-        //     if (part < n_per_row/64) {
-        //         auto vnorm = _mm256_set1_ps(norm);
-        //         __m256 sums[8] = {};
-        //         for (; part < n_per_row/64; ++part) {
-        //             for (int row = 0; row < nrows; ++row) {
-        //                 __m256 vg = _mm256_set1_ps(g[row]);
-        //                 auto qr = q + row*n_per_row + 64*part;
-        //                 for (int k = 0; k < 8; ++k) {
-        //                     auto vq = _mm256_loadu_ps(qr + 8*k);
-        //                     sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]);
-        //                 }
-        //             }
-        //             __m256 tot = _mm256_setzero_ps();
-        //             for (int k = 0; k < 8; ++k) {
-        //                 sums[k] = _mm256_mul_ps(vnorm, sums[k]);
-        //                 _mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]);
-        //                 tot = _mm256_fmadd_ps(sums[k], sums[k], tot);
-        //                 sums[k] = _mm256_setzero_ps();
-        //             }
-        //             sum_f += hsum_float_8(tot);
-        //         }
-        //     }
-        //#else
-        //     for (int row = 0; row < nrows; ++row) {
-        //         auto qr = q + row*n_per_row;
-        //         for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j];
-        //     }
-        //     for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; }
-        //#endif
-        //     mse = 0;
-        //     float sum_g = 0;
-        //     for (int row = 0; row < nrows; ++row) {
-        //         auto qr = q + row*n_per_row;
-        //         float sum = 0;
-        //#ifdef __AVX2__
-        //         __m256 vsum = _mm256_setzero_ps();
-        //         for (int j = 0; j < n_per_row; j += 8) {
-        //             auto vq = _mm256_loadu_ps(qr + j);
-        //             auto vf = _mm256_loadu_ps(f.data() + j);
-        //             vsum = _mm256_fmadd_ps(vq, vf, vsum);
-        //         }
-        //         sum = hsum_float_8(vsum);
-        //#else
-        //         for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
-        //#endif
-        //         g[row] = sum/sum_f;
-        //         sum_g += g[row]*g[row];
-        //#ifdef __AVX2__
-        //         auto vg = _mm256_set1_ps(g[row]);
-        //         auto vmse = _mm256_setzero_ps();
-        //         for (int j = 0; j < n_per_row; j += 8) {
-        //             auto vq = _mm256_loadu_ps(qr + j);
-        //             auto vf = _mm256_loadu_ps(f.data() + j);
-        //             auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf));
-        //             vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse);
-        //         }
-        //         mse += hsum_float_8(vmse);
-        //#else
-        //         for (int j = 0; j < n_per_row; ++j) {
-        //             float diff = qr[j] - g[row]*f[j];
-        //             mse += diff*diff;
-        //         }
-        //#endif
-        //     }
-        //     printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem));
-        //     if (mse_old/mse - 1 < 1e-6f) break;
-        //     mse_old = mse;
-        //     norm = 1.f/sum_g;
-        // }
+        if (true) {
+            quantize_iq2_k(f.data(), (block_iq2_k *)scratch, 1, n_per_row, nullptr);
+            dequantize_row_iq2_k((block_iq2_k *)scratch, f.data(), n_per_row);
+            quantize_iq2_k(g.data(), (block_iq2_k *)scratch, 1, nrows, nullptr);
+            dequantize_row_iq2_k((block_iq2_k *)scratch, g.data(), nrows);
+            //quantize_iq4_k(f.data(), (block_iq4_k *)scratch, 1, n_per_row, nullptr);
+            //dequantize_row_iq4_k((block_iq4_k *)scratch, f.data(), n_per_row);
+            //quantize_iq4_k(g.data(), (block_iq4_k *)scratch, 1, nrows, nullptr);
+            //dequantize_row_iq4_k((block_iq4_k *)scratch, g.data(), nrows);
+        }
+#ifdef __AVX2__
+        for (int row = 0; row < nrows; ++row) {
+            auto qr = q + row*n_per_row;
+            auto vg = _mm256_set1_ps(g[row]);
+            for (int j = 0; j < n_per_row; j += 8) {
+                auto vf = _mm256_loadu_ps(f.data() + j);
+                auto vq = _mm256_loadu_ps(qr + j);
+                vq = _mm256_sub_ps(vq, _mm256_mul_ps(vf, vg));
+                _mm256_storeu_ps(qr + j, vq);
+            }
+            g[row] = 1;
+        }
+#else
         for (int row = 0; row < nrows; ++row) {
             auto qr = q + row*n_per_row;
             for (int j = 0; j < n_per_row; ++j) {
@@ -441,12 +370,41 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns
             }
             g[row] = 1;
         }
+#endif
     }
     auto tim2 = std::chrono::steady_clock::now();
-    if (verbosity > 0) printf("%s: finished in %g s. Final rmse = %g\n", __func__, 1e-3*std::chrono::duration_cast<std::chrono::milliseconds>(tim2-tim1).count(), sqrt(mse_old/nelem));
+    if (verbosity > 0) {
+        max_error = 0;
+#ifdef __AVX2__
+        auto vmax = _mm256_setzero_ps();
+        auto sign = _mm256_set1_ps(-0.0f);
+        for (int row = 0; row < nrows; ++row) {
+            auto qr = q + row*n_per_row;
+            for (int j = 0; j < n_per_row; j += 8) {
+                auto vq = _mm256_loadu_ps(qr + j);
+                vmax = _mm256_max_ps(vmax, _mm256_andnot_ps(sign, vq));
+            }
+        }
+        __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(vmax, 1), _mm256_castps256_ps128(vmax));
+        max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+        max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+        max_error = _mm_cvtss_f32(max4);
+#else
+        for (int row = 0; row < nrows; ++row) {
+            auto qr = q + row*n_per_row;
+            for (int j = 0; j < n_per_row; ++j) {
+                max_error = std::max(max_error, std::abs(qr[j]));
+            }
+        }
+#endif
+        printf("%s: finished in %g s. Final rmse = %g max_error = %g\n", __func__,
+                1e-3*std::chrono::duration_cast<std::chrono::milliseconds>(tim2-tim1).count(), sqrt(mse_old/nelem), max_error);
+    }
 }


+#define SVD_BEFORE true
+
 // Run quantization function for a single layer and update error stats
 static void test_roundtrip_on_layer(
     std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference,
@@ -461,8 +419,26 @@ static void test_roundtrip_on_layer(
         if (input_scratch.size() < nelements) input_scratch.resize(nelements);
         input_scratch_ptr = input_scratch.data();
     }
-    if (quantized_scratch.size() < 4*nelements) quantized_scratch.resize(4*nelements);
     if (output_scratch.size() < nelements) output_scratch.resize(nelements);
+    if (quantized_scratch.size() < 4*nelements) quantized_scratch.resize(4*nelements);

+    bool fill_input = true;
+    if (SVD_BEFORE && nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) {
+        if (layer->type == GGML_TYPE_F16) {
+            for (int i = 0; i < nelements; i++) {
+                input_scratch[i] = ggml_get_f32_1d(layer, i);
+            }
+        } else {
+            printf("%s: f32 is not supported\n", __func__);
+            return;
+            //input_scratch = ggml_get_data_f32(layer) + 0;
+        }
+        std::memset(output_scratch.data(), 0, nelements*sizeof(float));
+        try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter, quantized_scratch.data(), verbosity);
+        std::memcpy(input_scratch_ptr, output_scratch.data(), nelements*sizeof(float));
+        fill_input = false;
+        //return;
+    }
+
     if (max_thread < 1) max_thread = std::thread::hardware_concurrency();
     int chunk_size = 32*512;
@@ -470,13 +446,13 @@ static void test_roundtrip_on_layer(

     if (num_chunks < 2 || max_thread < 2) {
         test_roundtrip_on_chunk(layer, 0, nelements, qfns, use_reference, input_scratch_ptr, quantized_scratch.data(),
-                output_scratch.data(), print_layer_stats ? layer_error : total_error);
+                output_scratch.data(), print_layer_stats ? layer_error : total_error, fill_input);
     } else {
         auto & stats = print_layer_stats ? layer_error : total_error;
         std::mutex mutex;
         uint64_t counter = 0;
         auto compute = [&mutex, &counter, &stats, &qfns, nelements, layer, use_reference, input_scratch_ptr,
-                &quantized_scratch, &output_scratch, chunk_size] () {
+                &quantized_scratch, &output_scratch, chunk_size, fill_input] () {
             error_stats local_stats {};
             while (true) {
                 std::unique_lock<std::mutex> lock(mutex);
@@ -488,7 +464,7 @@ static void test_roundtrip_on_layer(
                 lock.unlock();
                 uint64_t chunk = offset + chunk_size < nelements ? chunk_size : nelements - offset;
                 test_roundtrip_on_chunk(layer, offset, chunk, qfns, use_reference, input_scratch_ptr + offset,
-                        quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats);
+                        quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats, fill_input);
             }
         };
         int nthread = std::min(num_chunks, max_thread);
@@ -503,8 +479,8 @@ static void test_roundtrip_on_layer(
         combine_error_stats(total_error, layer_error);
     }

-    if (nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) {
-        try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter, verbosity);
+    if (!SVD_BEFORE && nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) {
+        try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter, quantized_scratch.data(), verbosity);
     }
 }