diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 501e1492..b52b15e8 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -305,6 +305,50 @@ static void do_svd_iteration(int n_per_row, int nrows, const float * q, float * } +static void try_lora(int n_per_row, int nrows, const float * x, float * q, int nsvd_iter, int verbosity = 1) { + constexpr int kNiter = 10; + if (nsvd_iter < 1) nsvd_iter = kNiter; + std::vector f(n_per_row, 1), aux(n_per_row), g(nrows, 1); + for (int iter = 0; iter < nsvd_iter; ++iter) { + float mse0 = 0; + for (int row = 0; row < nrows; ++row) { + const float * xr = x + row*n_per_row; + const float * qr = q + row*n_per_row; + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < n_per_row; ++j) { + float diff = xr[j] - g[row]*f[j]*qr[j]; + mse0 += diff*diff; + float w = f[j]*qr[j]; + sumqx += xr[j]*w; + sumq2 += w*w; + } + g[row] = sumq2 > 0 ? sumqx/sumq2 : 1; + } + std::memset(f.data(), 0, f.size()*sizeof(float)); + std::memset(aux.data(), 0, aux.size()*sizeof(float)); + for (int row = 0; row < nrows; ++row) { + const float * xr = x + row*n_per_row; + const float * qr = q + row*n_per_row; + for (int j = 0; j < n_per_row; ++j) { + float w = g[row]*qr[j]; + f[j] += w*xr[j]; + aux[j] += w*w; + } + } + for (int j = 0; j < n_per_row; ++j) if (aux[j] > 0) f[j] /= aux[j]; + float mse = 0; + for (int row = 0; row < nrows; ++row) { + const float * xr = x + row*n_per_row; + const float * qr = q + row*n_per_row; + for (int j = 0; j < n_per_row; ++j) { + float diff = xr[j] - g[row]*f[j]*qr[j]; + mse += diff*diff; + } + } + printf("%s(%d): rmse0 = %g, rmse = %g\n", __func__, iter, sqrt(mse0/(n_per_row*nrows)), sqrt(mse/(n_per_row*nrows))); + } +} + static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter, char * scratch, int verbosity = 1) { constexpr int kNiter = 10; if (nsvd_iter < 1) nsvd_iter = kNiter; @@ -407,7 +451,7 @@ static void test_roundtrip_on_layer( std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference, const ggml_tensor * layer, std::vector & input_scratch, std::vector & quantized_scratch, std::vector & output_scratch, error_stats & total_error, int nsvd_before, int nsvd_after, - int nsvd_iter, int verbosity, int max_thread = 0) { + bool do_lora, int nsvd_iter, int verbosity, int max_thread = 0) { assert(tensor_is_contiguous(layer)); error_stats layer_error {}; uint64_t nelements = ggml_nelements(layer); @@ -471,6 +515,10 @@ static void test_roundtrip_on_layer( for (auto& w : workers) w.join(); } + if (do_lora) { + try_lora(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd_iter, verbosity); + } + if (print_layer_stats) { print_error_stats(name, layer_error, false); combine_error_stats(total_error, layer_error); @@ -493,6 +541,7 @@ int main(int argc, char ** argv) { int nsvd_after = 0; int nsvd_iter = 0; int verbosity = 1; + bool do_lora = false; bool invalid_param = false; std::string arg; for (int i = 1; i < argc; i++) { @@ -509,6 +558,8 @@ int main(int argc, char ** argv) { params.per_layer_stats = true; } else if (arg == "--histogram") { params.print_histogram = true; + } else if (arg == "--lora") { + do_lora = true; } else if (arg == "--svd-before") { if (++i >= argc) { invalid_param = true; @@ -584,6 +635,10 @@ int main(int argc, char ** argv) { quantize_stats_print_usage(argc, argv); return 1; } + if (do_lora && (nsvd_before > 0 || nsvd_after > 0)) { + fprintf(stderr, "error: lora cannot be combined with SVD\n"); + return 1; + } print_build_info(); @@ -697,6 +752,7 @@ int main(int argc, char ** argv) { output_scratch, global_stats, nsvd_before, nsvd_after, + do_lora, nsvd_iter, verbosity, max_thread