mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-25 17:09:22 +00:00
FA: timing
This commit is contained in:
@@ -47,8 +47,57 @@
|
||||
// For fp16/fp32 matrix multiplications tiling is used to improve
|
||||
// performance.
|
||||
|
||||
#define FA_TIMING 1
|
||||
|
||||
#include <utility>
|
||||
#include <array>
|
||||
#if FA_TIMING
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
// Accumulates wall-clock time, in milliseconds, for up to five instrumented
// phases of the flash-attention kernel. Only compiled when FA_TIMING is set.
//
// Slot usage at the call sites in this file:
//   0 - per-q-step setup and conversion of Q to bf16
//   1 - K*Q multiplication
//   2 - update of the running max/sum (update_M_S)
//   3 - accumulation of the V*softmax(K*Q) result
//   4 - normalization and store of the result
//
// Intended usage: each compute() call owns a local, non-reporting Perf and
// fills it via accum_nolock(); at the end it merges the local totals into the
// process-wide instance() via add(), which takes the mutex. The process-wide
// instance prints a summary from its destructor.
struct Perf {
    // steady_clock is monotonic. high_resolution_clock may be an alias of
    // system_clock, which can jump when the wall clock is adjusted and would
    // then yield negative or inflated intervals.
    using Clock     = std::chrono::steady_clock;
    using TimePoint = std::chrono::time_point<Clock>;

    std::array<double, 5> times = {};  // accumulated milliseconds per slot
    std::mutex mutex;                  // guards `times` in accum() and add()
    bool report;                       // print a summary on destruction when true

    // Returns the current time point of the monotonic clock.
    static TimePoint cur_time() { return Clock::now(); }

    // Adds the time elapsed since `t1` to slot `what`, holding the mutex
    // while updating. `what` must be a valid index into `times`.
    inline void accum(int what, const TimePoint& t1) {
        const double elapsed_ms = delta(t1, cur_time());
        std::lock_guard<std::mutex> lock(mutex);
        times[what] += elapsed_ms;
    }

    // Same as accum() but without locking; for a Perf object that is only
    // touched by a single thread.
    inline void accum_nolock(int what, const TimePoint& t1) {
        times[what] += delta(t1, cur_time());
    }

    // Merges the per-slot totals of `other` into this object. Only this
    // object's mutex is taken; `other` is read without locking.
    inline void add(const Perf& other) {
        std::lock_guard<std::mutex> lock(mutex);
        for (std::size_t i = 0; i < times.size(); ++i) times[i] += other.times[i];
    }

    Perf(bool r) : report(r) {}

    // Prints the total and the per-slot breakdown (absolute and percentage)
    // when reporting is enabled and at least some time was recorded.
    ~Perf() {
        if (!report) return;
        double tot = 0;
        for (const double t : times) tot += t;
        if (!tot) return;
        printf("======================= Timing: %g ms in total\n", tot);
        for (int i = 0; i < int(times.size()); ++i) {
            if (times[i]) {
                printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
            }
        }
    }

    // Process-wide accumulator; reports when it is destroyed at program exit.
    static Perf& instance() {
        static Perf p(true);
        return p;
    }

    // Returns the interval from `t1` to `t2` in milliseconds.
    static double delta(const TimePoint& t1, const TimePoint& t2) {
        return std::chrono::duration<double, std::milli>(t2 - t1).count();
    }
};
|
||||
#endif
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define IQK_NOINLINE __declspec(noinline)
|
||||
@@ -13859,17 +13908,19 @@ struct FlashQKbf16 {
|
||||
fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
|
||||
#if FA_TIMING
|
||||
template <typename KHelper>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
|
||||
const char * mask, FlashMS<q_step, k_step>& fms, Perf& perf) {
|
||||
auto t1 = Perf::cur_time();
|
||||
#else
|
||||
template <typename KHelper>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
|
||||
const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
#endif
|
||||
{
|
||||
__m512bh qv[D/32];
|
||||
if constexpr (D <= 128) {
|
||||
//__m512bh vkh[D/8];
|
||||
//for (int l1 = 0; l1 < k_step; l1 += 4) {
|
||||
// kh.load_4(l1, vkh);
|
||||
// for (int j = 0; j < q_step; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
|
||||
//}
|
||||
__m512bh vkh[D/4];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 8) {
|
||||
kh.load_8(l1, vkh);
|
||||
@@ -13883,10 +13934,17 @@ struct FlashQKbf16 {
|
||||
}
|
||||
}
|
||||
}
|
||||
#if FA_TIMING
|
||||
perf.accum_nolock(1, t1);
|
||||
t1 = Perf::cur_time();
|
||||
#endif
|
||||
F16::Data vk[k_step/16];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
#if FA_TIMING
|
||||
perf.accum_nolock(2, t1);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename KHelper>
|
||||
@@ -13972,20 +14030,44 @@ struct FlashAttnBF16 {
|
||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float * qkv) {
|
||||
ggml_bf16_t q_bf16[q_step*D];
|
||||
#if FA_TIMING
|
||||
Perf perf(false);
|
||||
#endif
|
||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||
#if FA_TIMING
|
||||
auto t1 = Perf::cur_time();
|
||||
#endif
|
||||
fms.init_qstep();
|
||||
kh.reset_block();
|
||||
vh.reset_block();
|
||||
FlashQKbf16<D, q_step, k_step>::convert(stride_q, q, q_bf16);
|
||||
#if FA_TIMING
|
||||
perf.accum_nolock(0, t1);
|
||||
#endif
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
#if FA_TIMING
|
||||
//t1 = Perf::cur_time();
|
||||
FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
|
||||
//perf.accum_nolock(1, t1);
|
||||
t1 = Perf::cur_time();
|
||||
fqkv.accumulate_qkv(vh, fms);
|
||||
perf.accum_nolock(3, t1);
|
||||
#else
|
||||
FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
|
||||
fqkv.accumulate_qkv(vh, fms);
|
||||
#endif
|
||||
kh.next_block();
|
||||
vh.next_block();
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
#if FA_TIMING
|
||||
t1 = Perf::cur_time();
|
||||
#endif
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv);
|
||||
#if FA_TIMING
|
||||
perf.accum_nolock(4, t1);
|
||||
#endif
|
||||
|
||||
q += q_step*stride_q;
|
||||
mask += q_step*stride_m;
|
||||
@@ -14007,6 +14089,9 @@ struct FlashAttnBF16 {
|
||||
}
|
||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
|
||||
}
|
||||
#if FA_TIMING
|
||||
Perf::instance().add(perf);
|
||||
#endif
|
||||
}
|
||||
|
||||
FlashMS<q_step, k_step> fms;
|
||||
@@ -14058,12 +14143,13 @@ inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int
|
||||
if (nq1 >= 64) {
|
||||
FlashAttnBF16<D, 64, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
return;
|
||||
}
|
||||
else if (nq1 >= 16) {
|
||||
FlashAttnBF16<D, 16, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (nq1 >= 8) {
|
||||
FlashAttnBF16<D, 8, k_step> fa(scale, softcap);
|
||||
|
||||
Reference in New Issue
Block a user