From 301bcd4d21f2c1d200ccb0e18d94a29e9eacb440 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 14 Aug 2024 18:12:09 +0300 Subject: [PATCH] SVD POC: simdify (AVX2) --- examples/quantize-stats/CMakeLists.txt | 6 ++ examples/quantize-stats/quantize-stats.cpp | 89 ++++++++++++++++++++-- 2 files changed, 89 insertions(+), 6 deletions(-) diff --git a/examples/quantize-stats/CMakeLists.txt b/examples/quantize-stats/CMakeLists.txt index bb986a71..4172761a 100644 --- a/examples/quantize-stats/CMakeLists.txt +++ b/examples/quantize-stats/CMakeLists.txt @@ -1,4 +1,10 @@ set(TARGET llama-quantize-stats) +if (GGML_NATIVE) + if(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + message("-- Adding march=native to ${TARGET}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + endif() +endif() add_executable(${TARGET} quantize-stats.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT}) diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index eeec4517..7b499c3b 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -19,6 +19,10 @@ #include #include +#ifdef __AVX2__ +#include +#endif + #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif @@ -169,42 +173,107 @@ static void test_roundtrip_on_chunk( update_error_stats(chunk_size, input_scratch, output_scratch, stats); } -static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd) { +#ifdef __AVX2__ +static inline float hsum_float_4(__m128 x) { + x = _mm_add_ps(x, _mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); + return _mm_cvtss_f32(x); +} +static inline float hsum_float_8(__m256 x) { + return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); +} +#endif + +static void try_svd(int n_per_row, int nrows, const float * b, float * q, int nsvd, int nsvd_iter) { constexpr int kNiter = 10; + if (nsvd_iter < 1) nsvd_iter = kNiter; auto tim1 = std::chrono::steady_clock::now(); int nelem = n_per_row*nrows; double mse = 0; + bool use_avx2 = false; +#ifdef __AVX2__ + GGML_ASSERT(n_per_row%64 == 0); + use_avx2 = true; +#endif for (int j = 0; j < nelem; ++j) { q[j] = b[j] - q[j]; mse += q[j]*q[j]; } - printf("===================== %s(%d x %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, sqrt(mse/nelem)); + printf("===================== %s(%d x %d, %d, %d): rmse = %g\n", __func__, n_per_row, nrows, nsvd, use_avx2, sqrt(mse/nelem)); + float mse_old = mse; std::vector f(n_per_row), g(nrows, 1); for (int isvd = 0; isvd < nsvd; ++isvd) { printf("--- isvd = %d\n", isvd); float norm = 1.f/nrows; - for (int iter = 0; iter < kNiter; ++iter) { + for (int iter = 0; iter < nsvd_iter; ++iter) { std::memset(f.data(), 0, f.size()*sizeof(float)); + float sum_f = 0; +#ifdef __AVX2__ + auto vnorm = _mm256_set1_ps(norm); + __m256 sums[8] = {}; + for (int part = 0; part < n_per_row/64; ++part) { + for (int row = 0; row < nrows; ++row) { + __m256 vg = _mm256_set1_ps(g[row]); + auto qr = q + row*n_per_row + 64*part; + for (int k = 0; k < 8; ++k) { + auto vq = _mm256_loadu_ps(qr + 8*k); + sums[k] = _mm256_fmadd_ps(vg, vq, sums[k]); + } + } + __m256 tot = _mm256_setzero_ps(); + for (int k = 0; k < 8; ++k) { + sums[k] = _mm256_mul_ps(vnorm, sums[k]); + _mm256_storeu_ps(f.data() + 64*part + 8*k, sums[k]); + tot = _mm256_fmadd_ps(sums[k], sums[k], tot); + sums[k] = _mm256_setzero_ps(); + } + sum_f += hsum_float_8(tot); + } +#else for (int row = 0; row < nrows; ++row) { auto qr = q + row*n_per_row; for (int j = 0; j < n_per_row; ++j) f[j] += g[row]*qr[j]; } - float sum_f = 0; for (int j = 0; j < n_per_row; ++j) { f[j] *= norm; sum_f += f[j]*f[j]; } +#endif mse = 0; float sum_g = 0; for (int row = 0; row < nrows; ++row) { auto qr = q + row*n_per_row; float sum = 0; +#ifdef __AVX2__ + __m256 vsum = _mm256_setzero_ps(); + for (int j = 0; j < n_per_row; j += 8) { + auto vq = _mm256_loadu_ps(qr + j); + auto vf = _mm256_loadu_ps(f.data() + j); + vsum = _mm256_fmadd_ps(vq, vf, vsum); + } + sum = hsum_float_8(vsum); +#else for (int j = 0; j < n_per_row; ++j) sum += qr[j]*f[j]; +#endif g[row] = sum/sum_f; sum_g += g[row]*g[row]; +#ifdef __AVX2__ + auto vg = _mm256_set1_ps(g[row]); + auto vmse = _mm256_setzero_ps(); + for (int j = 0; j < n_per_row; j += 8) { + auto vq = _mm256_loadu_ps(qr + j); + auto vf = _mm256_loadu_ps(f.data() + j); + auto vdiff = _mm256_sub_ps(vq, _mm256_mul_ps(vg, vf)); + vmse = _mm256_fmadd_ps(vdiff, vdiff, vmse); + } + mse += hsum_float_8(vmse); +#else for (int j = 0; j < n_per_row; ++j) { float diff = qr[j] - g[row]*f[j]; mse += diff*diff; } +#endif } printf(" after %d iterations: %g\n", iter+1, sqrt(mse/nelem)); + if (mse_old/mse - 1 < 1e-6f) break; + mse_old = mse; norm = 1.f/sum_g; } for (int row = 0; row < nrows; ++row) { @@ -224,7 +293,7 @@ static void try_svd(int n_per_row, int nrows, const float * b, float * q, int ns static void test_roundtrip_on_layer( std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference, const ggml_tensor * layer, std::vector & input_scratch, std::vector & quantized_scratch, - std::vector & output_scratch, error_stats & total_error, int nsvd, int max_thread = 0) { + std::vector & output_scratch, error_stats & total_error, int nsvd, int nsvd_iter, int max_thread = 0) { assert(tensor_is_contiguous(layer)); error_stats layer_error {}; uint64_t nelements = ggml_nelements(layer); @@ -277,7 +346,7 @@ static void test_roundtrip_on_layer( } if (nsvd > 0 && layer->ne[0] > 1 && layer->ne[1] > 1 && layer->ne[2] == 1 && layer->ne[3] == 1) { - try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd); + try_svd(layer->ne[0], layer->ne[1], input_scratch_ptr, output_scratch.data(), nsvd, nsvd_iter); } } @@ -290,6 +359,7 @@ int main(int argc, char ** argv) { int max_thread = 0; int nsvd = 0; + int nsvd_iter = 0; bool invalid_param = false; std::string arg; for (int i = 1; i < argc; i++) { @@ -312,6 +382,12 @@ int main(int argc, char ** argv) { break; } nsvd = atoi(argv[i]); + } else if (arg == "-ni" || arg == "--svd-iterations") { + if (++i >= argc) { + invalid_param = true; + break; + } + nsvd_iter = atoi(argv[i]); } else if (arg == "-m" || arg == "--model") { if (++i >= argc) { invalid_param = true; @@ -476,6 +552,7 @@ int main(int argc, char ** argv) { output_scratch, global_stats, nsvd, + nsvd_iter, max_thread ); }