Fix FA on ARM (#346)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-04-25 11:01:08 +02:00
committed by GitHub
parent f176122a3d
commit 25d1a0dca8

View File

@@ -16142,7 +16142,13 @@ struct FlashQKV {
std::memcpy(S, fms.S, nq1*sizeof(float));
auto R = qkv_cache;
for (int j = 0; j < nq1; ++j) {
#ifdef __aarch64__
for (int i = 0; i < D/F16::block_size; ++i) {
F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i));
}
#else
std::memcpy(qkv, R, D*sizeof(float));
#endif
qkv += stride_qkv;
R += D;
}
@@ -16162,7 +16168,13 @@ struct FlashQKV {
std::memcpy(S, fms.S, q_step*sizeof(float));
auto R = qkv_cache;
for (int j = 0; j < q_step; ++j) {
#ifdef __aarch64__
for (int i = 0; i < D/F16::block_size; ++i) {
F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i));
}
#else
std::memcpy(qkv, R, D*sizeof(float));
#endif
qkv += stride_qkv;
R += D;
}