Fix GCC compilation errors on ARM (#309)

* Fix GCC compilation errors on ARM

* One more

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-04-03 15:50:53 +02:00
committed by GitHub
parent 07dbc1aa06
commit 2ee6263e24

View File

@@ -12735,7 +12735,7 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
int nb = n / 32; int nb = n / 32;
GGML_ASSERT(nb%4 == 0); GGML_ASSERT(nb%4 == 0);
uint8x16_t qx[8]; uint8x16_t qx[8];
int32x4_t acc[nrc_y] = {}; float32x4_t acc[nrc_y] = {};
auto ms = vdup_n_u16(0x8000); auto ms = vdup_n_u16(0x8000);
auto mask = vdupq_n_s8(0x03); auto mask = vdupq_n_s8(0x03);
float d8[4*nrc_y]; float d8[4*nrc_y];
@@ -14140,7 +14140,7 @@ void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& i
auto scale2 = vmulq_f32(scale2_x, scale_y); auto scale2 = vmulq_f32(scale2_x, scale_y);
info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0]))); info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1]))); info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f); acc[2*iy+0] = acc[2*iy+1] = vdupq_n_s32(0.f);
} }
} }
} }
@@ -14823,11 +14823,11 @@ inline float32x4_t v_tanh(float32x4_t x) {
return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask))); return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
//return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
} }
inline float32x4_t v_tanh(float16x8_t x) { //inline float32x4_t v_tanh(float16x8_t x) {
auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x))); // auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x))); // auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2)); // return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
} //}
inline float32x4_t v_silu(float32x4_t x) { inline float32x4_t v_silu(float32x4_t x) {
const float32x4_t one = vdupq_n_f32(1.0f); const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t zero = vdupq_n_f32(0.0f); const float32x4_t zero = vdupq_n_f32(0.0f);
@@ -15671,7 +15671,9 @@ struct HelperQ60 final : public BaseHelper<step> {
auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0; auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0;
#ifdef __aarch64__ #ifdef __aarch64__
// TODO // TODO
auto vd = F16::set1(*(const float16_t *)&dl->d); const float16_t * d16 = (const float16_t *)&dl->d;
auto vd = F16::set1(d16[0]);
//auto vd = F16::set1(*(const float16_t *)&dl->d);
auto qh8 = vld1_u8(dl->qh); auto qh8 = vld1_u8(dl->qh);
auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
auto qs = vld1q_u8(dl->qs); auto qs = vld1q_u8(dl->qs);
@@ -15819,7 +15821,7 @@ struct FlashMS {
return vmaxvq_f32(vmax); return vmaxvq_f32(vmax);
} }
inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) { inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) {
auto vzero = vdupq_n_f32(0); auto vzero = vdupq_n_f16(0);
auto vinf = vdupq_n_f32(-INFINITY); auto vinf = vdupq_n_f32(-INFINITY);
for (int l = 0; l < k_step/8; ++l) { for (int l = 0; l < k_step/8; ++l) {
auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mask + 8*l)); auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mask + 8*l));
@@ -15827,9 +15829,9 @@ struct FlashMS {
auto vm2 = vzip2q_u16(vm, vm); auto vm2 = vzip2q_u16(vm, vm);
auto kq = vld1q_f32_x2(cache + k_step*j + 8*l); auto kq = vld1q_f32_x2(cache + k_step*j + 8*l);
vk[2*l+0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1), vk[2*l+0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1),
vbicq_u32(vinf, vm1))); vbicq_u32(vreinterpretq_u32_f32(vinf), vm1)));
vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2), vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2),
vbicq_u32(vinf, vm2))); vbicq_u32(vreinterpretq_u32_f32(vinf), vm2)));
} }
float32x4_t vmax = vdupq_n_f32(-INFINITY); float32x4_t vmax = vdupq_n_f32(-INFINITY);
auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale)); auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));