mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
Faster BF16 Metal dot product
This commit is contained in:
@@ -1701,7 +1701,7 @@ void kernel_mul_mv_bf16_impl(

 device const uint16_t * x = (device const uint16_t *) (src0 + offset0);

-typedef union { uint32_t u; float f; } aux_t;
+typedef union { uint32_t u[4]; float f[4]; } aux_t;
 aux_t aux;

 for (int row = 0; row < N_MV_T_T; ++row) {
@@ -1713,9 +1713,12 @@ void kernel_mul_mv_bf16_impl(
 device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);

 float sumf = 0;
-for (int i = tiisg; i < ne00; i += 32) {
-    aux.u = x[i] << 16;
-    sumf += aux.f * (float) y[i];
+for (int i = tiisg; i < ne00/4; i += 32) {
+    aux.u[0] = x[4*i+0] << 16;
+    aux.u[1] = x[4*i+1] << 16;
+    aux.u[2] = x[4*i+2] << 16;
+    aux.u[3] = x[4*i+3] << 16;
+    sumf += aux.f[0] * (float)y[4*i+0] + aux.f[1] * (float)y[4*i+1] + aux.f[2] * (float)y[4*i+2] + aux.f[3] * (float)y[4*i+3];
 }

 float all_sum = simd_sum(sumf);
Reference in New Issue
Block a user