FA: timing

This commit is contained in:
Iwan Kawrakow
2025-01-18 13:47:08 +02:00
parent 7efe16f715
commit 96ce347243

View File

@@ -47,8 +47,57 @@
// For fp16/fp32 matrix multiplications, tiling is used to improve
// performance.
// Compile-time switch for the flash-attention timing instrumentation:
// when non-zero, struct Perf below and every `#if FA_TIMING` section are
// compiled in; set to 0 to build without any timing code.
#define FA_TIMING 1
#include <utility>
#include <array>
#if FA_TIMING
#include <chrono>
#include <mutex>
// Minimal wall-clock profiler used by the FA_TIMING instrumentation.
//
// Elapsed time is accumulated, in milliseconds, into a small fixed set of
// slots (`times`). Two usage patterns are supported:
//   * a per-call local `Perf(false)` updated with accum_nolock() and merged
//     into the shared instance once, via Perf::instance().add(local);
//   * direct updates of a shared instance through accum(), which locks.
// An instance constructed with report == true prints a summary of all
// non-zero slots from its destructor.
struct Perf {
    // steady_clock is monotonic, so an interval can never come out negative
    // or be distorted by a wall-clock adjustment. high_resolution_clock gives
    // no such guarantee (on some standard libraries it aliases system_clock).
    using Clock     = std::chrono::steady_clock;
    using TimePoint = std::chrono::time_point<Clock>;

    std::array<double, 5> times = {};  // accumulated milliseconds per slot
    std::mutex mutex;                  // guards `times` in accum() and add()
    bool report;                       // print the summary on destruction?

    // Current time stamp; pair it with accum()/accum_nolock()/delta().
    static auto cur_time() { return Clock::now(); }

    // Milliseconds elapsed between t1 and t2 (nanosecond count scaled by 1e-6).
    static double delta(const TimePoint& t1, const TimePoint& t2) {
        return 1e-6*std::chrono::duration_cast<std::chrono::nanoseconds>(t2 - t1).count();
    }

    // Adds the time elapsed since t1 to slot `what`, holding the mutex while
    // updating. `what` must be a valid index into `times`; it is not checked.
    inline void accum(int what, const TimePoint& t1) {
        // Take the end time stamp before locking so that time spent waiting
        // for the mutex is not charged to the measured section.
        const auto t2 = cur_time();
        const auto dt = delta(t1, t2);
        std::lock_guard<std::mutex> lock(mutex);
        times[what] += dt;
    }

    // Same as accum() but without locking: only for instances that are not
    // shared between threads. `what` must be a valid index into `times`.
    inline void accum_nolock(int what, const TimePoint& t1) {
        const auto t2 = cur_time();
        times[what] += delta(t1, t2);
    }

    // Adds every slot of `other` to this instance while holding this
    // instance's mutex. `other` itself is read without locking.
    inline void add(const Perf& other) {
        std::lock_guard<std::mutex> lock(mutex);
        for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i];
    }

    // `explicit` so that a bool is never silently converted into a Perf.
    explicit Perf(bool r) : report(r) {}

    // Prints the total and the per-slot breakdown when reporting is enabled
    // and something was actually measured.
    ~Perf() {
        if (!report) return;
        double tot = 0;
        for (const auto& t : times) tot += t;
        // Nothing measured: stay silent and never divide by a non-positive total.
        if (!(tot > 0)) return;
        printf("======================= Timing: %g ms in total\n", tot);
        for (int i = 0; i < int(times.size()); ++i) {
            if (times[i]) {
                printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
            }
        }
    }

    // Shared accumulator; being a function-local static with report == true,
    // it prints its summary when it is destroyed.
    static Perf& instance() {
        static Perf p(true);
        return p;
    }
};
#endif
#ifdef _MSC_VER
#define IQK_NOINLINE __declspec(noinline)
@@ -13859,17 +13908,19 @@ struct FlashQKbf16 {
fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
}
// Handles one k_step block for a tile of q_step query rows in two phases:
//   1. the scoped section below, which loads K data through `kh` and forms
//      the products with the bf16 queries `q`
//      (NOTE(review): results presumably land in fms.cache - confirm);
//   2. fms.update_M_S() for each of the q_step rows, passing that row's
//      mask pointer (mask + stride_m*j).
// When FA_TIMING is enabled the function takes an extra `Perf& perf` and adds
// the elapsed time of phase 1 to slot 1 and of phase 2 to slot 2. It uses
// accum_nolock(), so `perf` must not be shared between threads.
#if FA_TIMING
template <typename KHelper>
static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
const char * mask, FlashMS<q_step, k_step>& fms, Perf& perf) {
// Start of the phase-1 measurement.
auto t1 = Perf::cur_time();
#else
template <typename KHelper>
static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
const char * mask, FlashMS<q_step, k_step>& fms) {
#endif
// Extra scope: the temporaries of phase 1 go out of scope before phase 2.
{
__m512bh qv[D/32];
if constexpr (D <= 128) {
// Earlier variant that stepped through the block 4 positions at a time,
// kept for reference:
//__m512bh vkh[D/8];
//for (int l1 = 0; l1 < k_step; l1 += 4) {
// kh.load_4(l1, vkh);
// for (int j = 0; j < q_step; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
//}
// Current variant: step through the block 8 positions at a time.
__m512bh vkh[D/4];
for (int l1 = 0; l1 < k_step; l1 += 8) {
kh.load_8(l1, vkh);
@@ -13883,10 +13934,17 @@ struct FlashQKbf16 {
}
}
}
#if FA_TIMING
// Phase 1 done: charge it to slot 1 and restart the stopwatch for phase 2.
perf.accum_nolock(1, t1);
t1 = Perf::cur_time();
#endif
F16::Data vk[k_step/16];
for (int j = 0; j < q_step; ++j) {
fms.update_M_S(j, vk, mask + stride_m*j);
}
#if FA_TIMING
// Phase 2 (the update_M_S loop) is charged to slot 2.
perf.accum_nolock(2, t1);
#endif
}
template <typename KHelper>
@@ -13972,20 +14030,44 @@ struct FlashAttnBF16 {
// Runs the bf16 flash-attention computation over nq1 query rows, processed in
// tiles of q_step rows, against nk1/k_step key/value blocks per tile.
// For each full tile the steps are: reset the running state (fms, kh, vh),
// convert the tile of `q` into `q_bf16`, then for every k_step block call
// multiply_mask_kq() followed by fqkv.accumulate_qkv(), and finally
// fqkv.normalize_and_store() into `qkv`.
//
// With FA_TIMING enabled, elapsed time is collected in a Perf local to this
// call (no locking) using these slots:
//   0 - per-tile setup (init_qstep, reset_block, convert)
//   1 - first phase of multiply_mask_kq   (recorded inside that function)
//   2 - second phase of multiply_mask_kq  (recorded inside that function)
//   3 - accumulate_qkv
//   4 - normalize_and_store
// The local totals are merged into Perf::instance() once, at the end.
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
ggml_bf16_t q_bf16[q_step*D];
#if FA_TIMING
// report == false: this local instance never prints on destruction.
Perf perf(false);
#endif
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
#if FA_TIMING
auto t1 = Perf::cur_time();
#endif
fms.init_qstep();
kh.reset_block();
vh.reset_block();
FlashQKbf16<D, q_step, k_step>::convert(stride_q, q, q_bf16);
#if FA_TIMING
perf.accum_nolock(0, t1);
#endif
// `mr` walks along the mask for the current tile, one k_step block at a time.
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
#if FA_TIMING
// multiply_mask_kq records its own two phases (slots 1 and 2) through `perf`,
// which is why the coarse measurement around the call is commented out.
//t1 = Perf::cur_time();
FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
//perf.accum_nolock(1, t1);
t1 = Perf::cur_time();
fqkv.accumulate_qkv(vh, fms);
perf.accum_nolock(3, t1);
#else
FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
fqkv.accumulate_qkv(vh, fms);
#endif
kh.next_block();
vh.next_block();
// `mask` is a byte pointer: advance by k_step entries of sizeof(ggml_half) bytes.
mr += k_step*sizeof(ggml_half);
}
#if FA_TIMING
t1 = Perf::cur_time();
#endif
fqkv.normalize_and_store(fms, stride_qkv, qkv);
#if FA_TIMING
perf.accum_nolock(4, t1);
#endif
// Move on to the next tile of q_step query rows.
q += q_step*stride_q;
mask += q_step*stride_m;
@@ -14007,6 +14089,9 @@ struct FlashAttnBF16 {
}
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
}
#if FA_TIMING
// Merge this call's timings into the shared accumulator; add() locks that
// accumulator's mutex while it updates the totals.
Perf::instance().add(perf);
#endif
}
FlashMS<q_step, k_step> fms;
@@ -14058,12 +14143,13 @@ inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int
if (nq1 >= 64) {
FlashAttnBF16<D, 64, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
}
else if (nq1 >= 16) {
FlashAttnBF16<D, 16, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
}
return;
}
if (nq1 >= 8) {
FlashAttnBF16<D, 8, k_step> fa(scale, softcap);