mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
Fix FA on ARM (#346)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -16142,7 +16142,13 @@ struct FlashQKV {
|
|||||||
std::memcpy(S, fms.S, nq1*sizeof(float));
|
std::memcpy(S, fms.S, nq1*sizeof(float));
|
||||||
auto R = qkv_cache;
|
auto R = qkv_cache;
|
||||||
for (int j = 0; j < nq1; ++j) {
|
for (int j = 0; j < nq1; ++j) {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||||
|
F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i));
|
||||||
|
}
|
||||||
|
#else
|
||||||
std::memcpy(qkv, R, D*sizeof(float));
|
std::memcpy(qkv, R, D*sizeof(float));
|
||||||
|
#endif
|
||||||
qkv += stride_qkv;
|
qkv += stride_qkv;
|
||||||
R += D;
|
R += D;
|
||||||
}
|
}
|
||||||
@@ -16162,7 +16168,13 @@ struct FlashQKV {
|
|||||||
std::memcpy(S, fms.S, q_step*sizeof(float));
|
std::memcpy(S, fms.S, q_step*sizeof(float));
|
||||||
auto R = qkv_cache;
|
auto R = qkv_cache;
|
||||||
for (int j = 0; j < q_step; ++j) {
|
for (int j = 0; j < q_step; ++j) {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||||
|
F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i));
|
||||||
|
}
|
||||||
|
#else
|
||||||
std::memcpy(qkv, R, D*sizeof(float));
|
std::memcpy(qkv, R, D*sizeof(float));
|
||||||
|
#endif
|
||||||
qkv += stride_qkv;
|
qkv += stride_qkv;
|
||||||
R += D;
|
R += D;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user