From c44347e1371e2e81020ce9ce9b70843b1d2bfcf4 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 30 Sep 2024 19:38:15 +0300 Subject: [PATCH] Be able to do SVD before and after quantization --- examples/quantize-stats/quantize-stats.cpp | 32 ++++++++++++---------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 33fdf3b7..501e1492 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -340,7 +340,7 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns if (mse_old/this_mse - 1 < 1e-6f) break; mse_old = this_mse; } - if (true) { + if (false) { quantize_iq2_k(f.data(), (block_iq2_k *)scratch, 1, n_per_row, nullptr); dequantize_row_iq2_k((block_iq2_k *)scratch, f.data(), n_per_row); quantize_iq2_k(g.data(), (block_iq2_k *)scratch, 1, nrows, nullptr); @@ -402,14 +402,12 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns } } - -#define SVD_BEFORE true - // Run quantization function for a single layer and update error stats static void test_roundtrip_on_layer( std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference, const ggml_tensor * layer, std::vector & input_scratch, std::vector & quantized_scratch, - std::vector & output_scratch, error_stats & total_error, int nsvd, int nsvd_iter, int verbosity, int max_thread = 0) { + std::vector & output_scratch, error_stats & total_error, int nsvd_before, int nsvd_after, + int nsvd_iter, int verbosity, int max_thread = 0) { assert(tensor_is_contiguous(layer)); error_stats layer_error {}; uint64_t nelements = ggml_nelements(layer); @@ -423,7 +421,7 @@ static void test_roundtrip_on_layer( if (quantized_scratch.size() < 4*nelements) quantized_scratch.resize(4*nelements); bool fill_input = true; - if (SVD_BEFORE && nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) { + if (nsvd_before > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) { if (layer->type == GGML_TYPE_F16) { for (int i = 0; i < nelements; i++) { input_scratch[i] = ggml_get_f32_1d(layer, i); @@ -434,10 +432,9 @@ static void test_roundtrip_on_layer( //input_scratch = ggml_get_data_f32(layer) + 0; } std::memset(output_scratch.data(), 0, nelements*sizeof(float)); - try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter, quantized_scratch.data(), verbosity); + try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd_before, nsvd_iter, quantized_scratch.data(), verbosity); std::memcpy(input_scratch_ptr, output_scratch.data(), nelements*sizeof(float)); fill_input = false; - //return; } if (max_thread < 1) max_thread = std::thread::hardware_concurrency(); @@ -479,8 +476,8 @@ static void test_roundtrip_on_layer( combine_error_stats(total_error, layer_error); } - if (!SVD_BEFORE && nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) { - try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter, quantized_scratch.data(), verbosity); + if (nsvd_after > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) { + try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd_after, nsvd_iter, quantized_scratch.data(), verbosity); } } @@ -492,7 +489,8 @@ int main(int argc, char ** argv) { // read command line int max_thread = 0; - int nsvd = 0; + int nsvd_before = 0; + int nsvd_after = 0; int nsvd_iter = 0; int verbosity = 1; bool invalid_param = false; @@ -511,12 +509,18 @@ int main(int argc, char ** argv) { params.per_layer_stats = true; } else if (arg == "--histogram") { params.print_histogram = true; - } else if (arg == "-svd") { + } else if (arg == "--svd-before") { if (++i >= argc) { invalid_param = true; break; } - nsvd = atoi(argv[i]); + nsvd_before = atoi(argv[i]); + } else if (arg == "--svd-after") { + if (++i >= argc) { + invalid_param = true; + break; + } + nsvd_after = atoi(argv[i]); } else if (arg == "-ni" || arg == "--svd-iterations") { if (++i >= argc) { invalid_param = true; @@ -692,7 +696,7 @@ int main(int argc, char ** argv) { quantized_scratch, output_scratch, global_stats, - nsvd, + nsvd_before, nsvd_after, nsvd_iter, verbosity, max_thread