mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-21 13:44:10 +00:00
POC SVD: try involving the quantized weights.
This commit is contained in:
@@ -305,6 +305,50 @@ static void do_svd_iteration(int n_per_row, int nrows, const float * q, float *
|
||||
|
||||
}
|
||||
|
||||
// Proof-of-concept: fits a rank-1 multiplicative correction to q so that it better matches x.
//
// Every value is modelled as x[row][j] ~= g[row] * f[j] * q[row][j], with one scale g per row
// and one scale f per column. The two scale vectors are refined by alternating least squares:
// with f held fixed each g[row] has a closed-form minimizer, and with g held fixed so does
// each f[j]. Neither x nor q is modified and the scales are discarded on return; the only
// output is the RMSE report printed for each iteration.
//
//   n_per_row  number of elements per row of x and q
//   nrows      number of rows of x and q
//   x          reference values, nrows*n_per_row floats stored row by row
//   q          values to be corrected, same layout as x
//              (presumably the de-quantized weights - confirm against the caller)
//   nsvd_iter  number of alternating iterations; values < 1 select the default of 10
//   verbosity  the per-iteration report is printed only when this is > 0
static void try_lora(int n_per_row, int nrows, const float * x, float * q, int nsvd_iter, int verbosity = 1) {
    constexpr int kNiter = 10;
    if (nsvd_iter < 1) nsvd_iter = kNiter;
    // Nothing to fit, and the RMSE normalization below would otherwise be 0/0.
    if (n_per_row <= 0 || nrows <= 0) return;

    const bool   report     = verbosity > 0;
    // Row offsets and the element count are formed in wide types: for large tensors
    // row*n_per_row does not fit in an int.
    const size_t row_stride = static_cast<size_t>(n_per_row);
    const double n_total    = static_cast<double>(n_per_row)*static_cast<double>(nrows);

    std::vector<float>  f(n_per_row, 1), g(nrows, 1);
    // Per-column accumulators for the f update; double because they sum over all rows.
    std::vector<double> num(n_per_row), den(n_per_row);

    for (int iter = 0; iter < nsvd_iter; ++iter) {
        // Step 1: update the row scales g with f fixed. mse0 is the squared error of the
        // model as it stood on entry to this iteration (each row uses its not-yet-updated g).
        double mse0 = 0;
        for (int row = 0; row < nrows; ++row) {
            const float * xr = x + row*row_stride;
            const float * qr = q + row*row_stride;
            double sumqx = 0, sumq2 = 0;
            for (int j = 0; j < n_per_row; ++j) {
                const double w    = static_cast<double>(f[j])*qr[j];
                const double diff = xr[j] - g[row]*w;
                mse0  += diff*diff;
                sumqx += xr[j]*w;
                sumq2 += w*w;
            }
            // A row whose scaled q is entirely zero carries no information; fall back to 1.
            g[row] = sumq2 > 0 ? static_cast<float>(sumqx/sumq2) : 1.0f;
        }

        // Step 2: update the column scales f with the new g fixed.
        num.assign(num.size(), 0.0);
        den.assign(den.size(), 0.0);
        for (int row = 0; row < nrows; ++row) {
            const float * xr = x + row*row_stride;
            const float * qr = q + row*row_stride;
            for (int j = 0; j < n_per_row; ++j) {
                const double w = static_cast<double>(g[row])*qr[j];
                num[j] += w*xr[j];
                den[j] += w*w;
            }
        }
        for (int j = 0; j < n_per_row; ++j) {
            // den[j] == 0 implies num[j] == 0 as well, so such a column scale is set to 0.
            f[j] = den[j] > 0 ? static_cast<float>(num[j]/den[j]) : 0.0f;
        }

        // The remaining pass exists only to produce the report.
        if (!report) continue;

        double mse = 0;
        for (int row = 0; row < nrows; ++row) {
            const float * xr = x + row*row_stride;
            const float * qr = q + row*row_stride;
            for (int j = 0; j < n_per_row; ++j) {
                const double diff = xr[j] - static_cast<double>(g[row])*f[j]*qr[j];
                mse += diff*diff;
            }
        }
        printf("%s(%d): rmse0 = %g, rmse = %g\n", __func__, iter, sqrt(mse0/n_total), sqrt(mse/n_total));
    }
}
|
||||
|
||||
static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter, char * scratch, int verbosity = 1) {
|
||||
constexpr int kNiter = 10;
|
||||
if (nsvd_iter < 1) nsvd_iter = kNiter;
|
||||
@@ -407,7 +451,7 @@ static void test_roundtrip_on_layer(
|
||||
std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference,
|
||||
const ggml_tensor * layer, std::vector<float> & input_scratch, std::vector<char> & quantized_scratch,
|
||||
std::vector<float> & output_scratch, error_stats & total_error, int nsvd_before, int nsvd_after,
|
||||
int nsvd_iter, int verbosity, int max_thread = 0) {
|
||||
bool do_lora, int nsvd_iter, int verbosity, int max_thread = 0) {
|
||||
assert(tensor_is_contiguous(layer));
|
||||
error_stats layer_error {};
|
||||
uint64_t nelements = ggml_nelements(layer);
|
||||
@@ -471,6 +515,10 @@ static void test_roundtrip_on_layer(
|
||||
for (auto& w : workers) w.join();
|
||||
}
|
||||
|
||||
if (do_lora) {
|
||||
try_lora(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd_iter, verbosity);
|
||||
}
|
||||
|
||||
if (print_layer_stats) {
|
||||
print_error_stats(name, layer_error, false);
|
||||
combine_error_stats(total_error, layer_error);
|
||||
@@ -493,6 +541,7 @@ int main(int argc, char ** argv) {
|
||||
int nsvd_after = 0;
|
||||
int nsvd_iter = 0;
|
||||
int verbosity = 1;
|
||||
bool do_lora = false;
|
||||
bool invalid_param = false;
|
||||
std::string arg;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
@@ -509,6 +558,8 @@ int main(int argc, char ** argv) {
|
||||
params.per_layer_stats = true;
|
||||
} else if (arg == "--histogram") {
|
||||
params.print_histogram = true;
|
||||
} else if (arg == "--lora") {
|
||||
do_lora = true;
|
||||
} else if (arg == "--svd-before") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@@ -584,6 +635,10 @@ int main(int argc, char ** argv) {
|
||||
quantize_stats_print_usage(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
if (do_lora && (nsvd_before > 0 || nsvd_after > 0)) {
|
||||
fprintf(stderr, "error: lora cannot be combined with SVD\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
print_build_info();
|
||||
|
||||
@@ -697,6 +752,7 @@ int main(int argc, char ** argv) {
|
||||
output_scratch,
|
||||
global_stats,
|
||||
nsvd_before, nsvd_after,
|
||||
do_lora,
|
||||
nsvd_iter,
|
||||
verbosity,
|
||||
max_thread
|
||||
|
||||
Reference in New Issue
Block a user