Faster BF16 Metal dot product

This commit is contained in:
Iwan Kawrakow
2024-09-16 17:32:48 +02:00
parent c1d0af0a38
commit 8e80d15930

View File

@@ -1701,7 +1701,7 @@ void kernel_mul_mv_bf16_impl(
device const uint16_t * x = (device const uint16_t *) (src0 + offset0);
typedef union { uint32_t u; float f; } aux_t;
typedef union { uint32_t u[4]; float f[4]; } aux_t;
aux_t aux;
for (int row = 0; row < N_MV_T_T; ++row) {
@@ -1713,9 +1713,12 @@ void kernel_mul_mv_bf16_impl(
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
aux.u = x[i] << 16;
sumf += aux.f * (float) y[i];
for (int i = tiisg; i < ne00/4; i += 32) {
aux.u[0] = x[4*i+0] << 16;
aux.u[1] = x[4*i+1] << 16;
aux.u[2] = x[4*i+2] << 16;
aux.u[3] = x[4*i+3] << 16;
sumf += aux.f[0] * (float)y[4*i+0] + aux.f[1] * (float)y[4*i+1] + aux.f[2] * (float)y[4*i+2] + aux.f[3] * (float)y[4*i+3];
}
float all_sum = simd_sum(sumf);