mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 14:44:09 +00:00
POC: add ability to try SVD on the difference between model and quantized model
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <chrono>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
@@ -168,13 +169,62 @@ static void test_roundtrip_on_chunk(
|
||||
update_error_stats(chunk_size, input_scratch, output_scratch, stats);
|
||||
}
|
||||
|
||||
// Proof of concept: approximate the residual (b - q) by a sum of up to nsvd rank-1 terms
// g * f^T and print how the RMSE of what is left evolves.
//
// Each rank-1 term is obtained with kNiter alternating sweeps:
//   f <- norm * sum_row g[row] * residual[row]      (one value per column)
//   g[row] <- (residual[row] . f) / (f . f)          (one value per row)
// after which g[row]*f[j] is subtracted from the residual and g is reset to all ones.
//
// Parameters:
//   n_per_row - number of elements in one row (row stride of b and q)
//   nrows     - number of rows
//   b         - reference values, n_per_row*nrows floats, rows stored contiguously
//   q         - values to compare against b (the call site in this file passes the
//               round-trip output buffer); OVERWRITTEN in place, first with b - q and then
//               with whatever remains after the extracted rank-1 terms are subtracted
//   nsvd      - maximum number of rank-1 terms to extract
//
// Output is diagnostic only (printf); nothing is returned.
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd) {
    constexpr int kNiter = 10;

    // Nothing to analyse; also avoids 0/0 in the RMSE and 1/0 in the initial norm below.
    if (n_per_row <= 0 || nrows <= 0) {
        return;
    }

    auto tim1 = std::chrono::steady_clock::now();

    // 64-bit sizes: n_per_row*nrows and row*n_per_row overflow int for large tensors.
    const int64_t row_size = n_per_row;
    const int64_t nelem    = row_size*nrows;

    // Turn q into the residual b - q and measure its starting RMSE.
    double mse = 0;
    for (int64_t j = 0; j < nelem; ++j) {
        q[j] = b[j] - q[j];
        mse += q[j]*q[j];
    }
    printf("===================== %s(%d x %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, sqrt(mse/nelem));

    std::vector<float> f(n_per_row), g(nrows, 1);
    bool degenerate = false;

    for (int isvd = 0; isvd < nsvd && !degenerate; ++isvd) {
        printf("--- isvd = %d\n", isvd);
        // g starts as all ones, so sum(g^2) == nrows.
        float norm = 1.f/nrows;
        for (int iter = 0; iter < kNiter; ++iter) {
            // f = norm * residual^T * g
            std::memset(f.data(), 0, f.size()*sizeof(float));
            for (int row = 0; row < nrows; ++row) {
                const float * qr = q + row*row_size;
                for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j];
            }
            float sum_f = 0;
            for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; }

            // f vanished (residual is zero, or has no component along the current g), or
            // became non-finite. Dividing by sum_f would fill g - and then q - with NaN.
            if (!(sum_f > 0)) {
                degenerate = true;
                break;
            }

            // g = residual * f / (f . f), and the RMSE left after removing g * f^T
            mse = 0;
            float sum_g = 0;
            for (int row = 0; row < nrows; ++row) {
                const float * qr = q + row*row_size;
                float sum = 0;
                for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
                g[row] = sum/sum_f;
                sum_g += g[row]*g[row];
                for (int j = 0; j < n_per_row; ++j) {
                    float diff = qr[j] - g[row]*f[j];
                    mse += diff*diff;
                }
            }
            printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem));

            // Same protection for the next sweep's normalisation (1/sum_g).
            if (!(sum_g > 0)) {
                degenerate = true;
                break;
            }
            norm = 1.f/sum_g;
        }

        if (degenerate) {
            // Leave the residual untouched rather than subtracting an invalid component.
            printf(" no usable component found, stopping after %d component(s)\n", isvd);
            break;
        }

        // Remove the component just found and restart from g == 1 for the next one.
        for (int row = 0; row < nrows; ++row) {
            float * qr = q + row*row_size;
            for (int j = 0; j < n_per_row; ++j) {
                qr[j] -= g[row]*f[j];
            }
            g[row] = 1;
        }
    }

    auto tim2 = std::chrono::steady_clock::now();
    printf("%s: finished in %g s\n", __func__, 1e-3*std::chrono::duration_cast<std::chrono::milliseconds>(tim2-tim1).count());
}
|
||||
|
||||
|
||||
// Run quantization function for a single layer and update error stats
|
||||
static void test_roundtrip_on_layer(
|
||||
std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference,
|
||||
const ggml_tensor * layer, std::vector<float> & input_scratch, std::vector<char> & quantized_scratch,
|
||||
std::vector<float> & output_scratch, error_stats & total_error, int max_thread = 0
|
||||
) {
|
||||
std::vector<float> & output_scratch, error_stats & total_error, int nsvd, int max_thread = 0) {
|
||||
assert(tensor_is_contiguous(layer));
|
||||
error_stats layer_error {};
|
||||
uint64_t nelements = ggml_nelements(layer);
|
||||
@@ -225,6 +275,10 @@ static void test_roundtrip_on_layer(
|
||||
print_error_stats(name, layer_error, false);
|
||||
combine_error_stats(total_error, layer_error);
|
||||
}
|
||||
|
||||
if (nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) {
|
||||
try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
@@ -235,6 +289,7 @@ int main(int argc, char ** argv) {
|
||||
// read command line
|
||||
|
||||
int max_thread = 0;
|
||||
int nsvd = 0;
|
||||
bool invalid_param = false;
|
||||
std::string arg;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
@@ -251,6 +306,12 @@ int main(int argc, char ** argv) {
|
||||
params.per_layer_stats = true;
|
||||
} else if (arg == "--histogram") {
|
||||
params.print_histogram = true;
|
||||
} else if (arg == "-svd") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
nsvd = atoi(argv[i]);
|
||||
} else if (arg == "-m" || arg == "--model") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@@ -414,6 +475,7 @@ int main(int argc, char ** argv) {
|
||||
quantized_scratch,
|
||||
output_scratch,
|
||||
global_stats,
|
||||
nsvd,
|
||||
max_thread
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user