// Experimental helper: reports how much of a 2-D tensor's quantization error can be
// absorbed by a low-rank correction.
//
// In the patch this comes from (examples/quantize-stats/quantize-stats.cpp) it is called
// at the end of test_roundtrip_on_layer when nsvd > 0 and the layer has ne[0] > 1,
// ne[1] > 1, ne[2] == 1 and ne[3] == 1. It relies on <chrono>, <cmath>, <cstdio> and
// <vector> being included by the file.
//
// Parameters:
//   n_per_row - number of floats in one row
//   nrows     - number of rows
//   b         - original values, nrows rows of n_per_row contiguous floats (read only)
//   q         - values after the quantize/dequantize round trip, same layout as b.
//               OVERWRITTEN: on return it holds whatever is left of the residual b - q
//               after the rank-1 terms have been removed.
//   nsvd      - maximum number of rank-1 terms g*f^T to remove
//
// Each term is obtained with kNiter alternating least-squares updates of f (length
// n_per_row) and g (length nrows), always starting from g = 1. The RMSE of the residual
// is printed before the first term and after every iteration, followed by the elapsed
// wall-clock time.
//
// Nothing is done when a dimension is non-positive or a pointer is null. The search for
// further terms stops as soon as an update collapses to zero (residual already zero, or
// orthogonal to the all-ones start vector); q is left untouched by that term, so it
// never receives NaN from a 0/0 division.
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd) {
    constexpr int kNiter = 10;
    if (n_per_row <= 0 || nrows <= 0 || b == nullptr || q == nullptr) {
        return;
    }
    const auto tim1 = std::chrono::steady_clock::now();

    // Element counts and row offsets are kept in size_t: n_per_row*nrows can exceed INT_MAX.
    const size_t row_size = static_cast<size_t>(n_per_row);
    const size_t nelem    = row_size*static_cast<size_t>(nrows);

    // Turn q into the residual and measure the starting error.
    double mse = 0;
    for (size_t j = 0; j < nelem; ++j) {
        q[j] = b[j] - q[j];
        mse += static_cast<double>(q[j])*q[j];
    }
    printf("===================== %s(%d x %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, sqrt(mse/nelem));

    std::vector<float> f(row_size), g(static_cast<size_t>(nrows), 1.f);
    for (int isvd = 0; isvd < nsvd; ++isvd) {
        printf("--- isvd = %d\n", isvd);
        float norm = 1.f/nrows; // 1/|g|^2 for the initial g = 1
        bool collapsed = false;
        for (int iter = 0; iter < kNiter; ++iter) {
            // f = norm * Q^T g : least-squares f for the current g
            f.assign(row_size, 0.f);
            for (int row = 0; row < nrows; ++row) {
                const float * qr = q + row*row_size;
                for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j];
            }
            float sum_f = 0;
            for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; }
            if (!(sum_f > 0)) { // zero (or NaN): dividing by it below would poison g and q
                collapsed = true;
                break;
            }
            // g = Q f / |f|^2 : least-squares g for the current f
            mse = 0;
            float sum_g = 0;
            for (int row = 0; row < nrows; ++row) {
                const float * qr = q + row*row_size;
                float sum = 0;
                for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j];
                g[row] = sum/sum_f;
                sum_g += g[row]*g[row];
                for (int j = 0; j < n_per_row; ++j) {
                    const double diff = qr[j] - g[row]*f[j];
                    mse += diff*diff;
                }
            }
            printf("    after %d iterations: %g\n", iter+1, sqrt(mse/nelem));
            if (!(sum_g > 0)) {
                collapsed = true;
                break;
            }
            norm = 1.f/sum_g;
        }
        if (collapsed) {
            printf("    rank-1 update vanished, stopping after %d component(s)\n", isvd);
            break;
        }
        // Remove the term just found and reset g for the next one.
        for (int row = 0; row < nrows; ++row) {
            float * qr = q + row*row_size;
            for (int j = 0; j < n_per_row; ++j) {
                qr[j] -= g[row]*f[j];
            }
            g[row] = 1;
        }
    }

    const auto tim2 = std::chrono::steady_clock::now();
    printf("%s: finished in %g s\n", __func__, 1e-3*std::chrono::duration_cast<std::chrono::milliseconds>(tim2-tim1).count());
}