mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-09 16:00:12 +00:00
BF16 support on Metal (#56)
* BF16 support on Metal
* Faster BF16 Metal dot product

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -113,6 +113,8 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
||||
@@ -146,6 +148,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
||||
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
|
||||
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
||||
@@ -176,6 +179,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
|
||||
@@ -206,6 +210,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
|
||||
@@ -628,6 +633,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F16, mul_mv_bf16_f16, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
||||
@@ -661,6 +668,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
|
||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
|
||||
@@ -691,6 +699,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32, mul_mv_id_iq6_k_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
||||
@@ -721,6 +730,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, mul_mm_iq6_k_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
||||
@@ -856,9 +866,12 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
|
||||
}
|
||||
|
||||
static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
|
||||
|
||||
for (size_t i = 0, n = 3; i < n; ++i) {
|
||||
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
if (op->op != GGML_OP_MUL_MAT && op->op != GGML_OP_MUL_MAT_ID && op->op != GGML_OP_GET_ROWS) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1861,6 +1874,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
||||
@@ -1945,6 +1959,19 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
nrows = 4;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
if (src1t == GGML_TYPE_F32) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
||||
}
|
||||
else if (src1t == GGML_TYPE_F16) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F16].pipeline;
|
||||
}
|
||||
else {
|
||||
GGML_ABORT("not implemented");
|
||||
}
|
||||
nrows = 4;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
@@ -2229,6 +2256,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
||||
@@ -2307,6 +2335,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
nth1 = 1;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
|
||||
@@ -1667,11 +1667,121 @@ kernel void kernel_mul_mv(
|
||||
tiisg);
|
||||
}
|
||||
|
||||
// Matrix-vector product where src0 holds bf16 values (read as raw uint16_t)
// and src1 holds elements of type T1; results are written to dst as float.
//
// Work split, as visible in the body:
//   - tgpig.x selects the src0 row (r0),
//   - tgpig.y selects a group of N_MV_T_T consecutive src1 rows (starting at rb),
//   - tgpig.z is the batch index im, decomposed into i12 = im % ne12 and
//     i13 = im / ne12; r2 and r3 divide those indices when addressing src0.
//   - the 32 lanes (tiisg, step hard-coded to 32) each accumulate groups of
//     four elements, the partial sums are combined with simd_sum, and lane 0
//     stores the result.
//
// A bf16 value is widened to float by placing its 16 bits in the upper half
// of a 32-bit word and reinterpreting that word as float.
//
// Elements within a row are indexed directly (x[i], y[i]), i.e. rows are
// treated as contiguous; ne01, nb00, ne10 and nb10 are accepted but not used.
//
// Changes relative to the previous version:
//   - elements past the last full group of four (ne00 % 4 != 0) are now
//     included in the sum instead of being silently dropped,
//   - the src0 byte offset is kept in 64 bits so it cannot wrap at 4 GiB,
//   - the uint16_t value is cast to uint32_t before the shift (as
//     dequantize_bf16 does) so the shift never happens on a promoted signed int.
template<typename T1>
void kernel_mul_mv_bf16_impl(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        int64_t  ne00,
        int64_t  ne01,
        int64_t  ne02,
        uint64_t nb00,
        uint64_t nb01,
        uint64_t nb02,
        int64_t  ne10,
        int64_t  ne11,
        int64_t  ne12,
        uint64_t nb10,
        uint64_t nb11,
        uint64_t nb12,
        int64_t  ne0,
        int64_t  ne1,
        uint     r2,
        uint     r3,
        uint3    tgpig,
        uint     tiisg) {
    const int64_t r0 = tgpig.x;
    const int64_t rb = tgpig.y*N_MV_T_T;
    const int64_t im = tgpig.z;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // The right-hand side is evaluated in 64 bits; keep it that way instead of
    // truncating to a 32-bit uint.
    const uint64_t offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

    device const uint16_t * x = (device const uint16_t *) (src0 + offset0);

    typedef union { uint32_t u[4]; float f[4]; } aux_t;
    aux_t aux;

    // Number of complete groups of four elements in a row.
    const int n4 = ne00/4;

    for (int row = 0; row < N_MV_T_T; ++row) {
        int r1 = rb + row;
        if (r1 >= ne11) {
            // rb and row are the same for every lane, so all lanes leave together.
            break;
        }

        device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);

        float sumf = 0;
        for (int i = tiisg; i < n4; i += 32) {
            aux.u[0] = (uint32_t)x[4*i+0] << 16;
            aux.u[1] = (uint32_t)x[4*i+1] << 16;
            aux.u[2] = (uint32_t)x[4*i+2] << 16;
            aux.u[3] = (uint32_t)x[4*i+3] << 16;
            sumf += aux.f[0] * (float)y[4*i+0] + aux.f[1] * (float)y[4*i+1] + aux.f[2] * (float)y[4*i+2] + aux.f[3] * (float)y[4*i+3];
        }

        float all_sum = simd_sum(sumf);
        if (tiisg == 0) {
            // Leftover elements when ne00 is not a multiple of 4 (at most 3).
            for (int i = 4*n4; i < ne00; ++i) {
                aux.u[0] = (uint32_t)x[i] << 16;
                all_sum += aux.f[0] * (float)y[i];
            }
            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
        }
    }
}
|
||||
|
||||
// Kernel entry point for the bf16 matrix-vector product.
//
// It performs no computation of its own: every buffer, every shape/stride
// constant and the two thread-position values are handed unchanged, in the
// same order, to kernel_mul_mv_bf16_impl<T1>.
template<typename T1>
kernel void kernel_mul_mv_bf16(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant      uint & r2,
        constant      uint & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]]) {
    kernel_mul_mv_bf16_impl<T1>(
        src0, src1, dst,        // buffers
        ne00, ne01, ne02,       // src0 extents
        nb00, nb01, nb02,       // src0 strides
        ne10, ne11, ne12,       // src1 extents
        nb10, nb11, nb12,       // src1 strides
        ne0,  ne1,              // dst extents
        r2,   r3,               // divisors applied to i12 / i13 by the impl
        tgpig, tiisg);          // thread position
}
|
||||
|
||||
// mul_mv_t is the function type of kernel_mul_mv<half, half4, half, half4>.
// Every explicit instantiation below is declared with that one type, including
// the two kernel_mul_mv_bf16 instantiations, so kernel_mul_mv_bf16<T1> has to
// keep the same parameter list as kernel_mul_mv.
// NOTE(review): the host_name strings are presumably what the host-side
// GGML_METAL_ADD_KERNEL entries look up — confirm against the macro definition.
typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
|
||||
template [[host_name("kernel_mul_mv_bf16_f16")]] kernel mul_mv_t kernel_mul_mv_bf16<half>;
|
||||
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv_bf16<float>;
|
||||
|
||||
template<typename T, typename T4>
|
||||
kernel void kernel_mul_mv_1row(
|
||||
@@ -6611,6 +6721,17 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
||||
}
|
||||
}
|
||||
|
||||
// Expands 16 consecutive bf16 values into the 4x4 register tile `reg`.
//
// `src` is declared as half4x4 but is only used as a base address: its memory
// is read as 16 raw uint16_t values. Each value is moved into the upper 16 bits
// of a 32-bit word, and that word is reinterpreted as a float before being
// stored (converted to the element type of type4x4 on assignment).
// Value k lands in reg[k/4][k%4]. The `il` argument is not used.
template <typename type4x4>
void dequantize_bf16(device const half4x4 * src, short il, thread type4x4 & reg) {
    typedef union { uint32_t u; float f; } bits_t;

    device const uint16_t * raw = (device const uint16_t *)src;

    for (int row = 0; row < 4; ++row) {
        for (int col = 0; col < 4; ++col) {
            bits_t b;
            b.u = (uint32_t)raw[4*row + col] << 16;
            reg[row][col] = b.f;
        }
    }
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
@@ -7692,6 +7813,7 @@ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_ro
|
||||
|
||||
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
||||
|
||||
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_q_t kernel_get_rows_q<half4x4, 1, dequantize_bf16>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
||||
@@ -7732,6 +7854,7 @@ typedef decltype(kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequanti
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_f16>>;
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_bf16>>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_0, 2, dequantize_q4_0>>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_1, 2, dequantize_q4_1>>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_0, 2, dequantize_q5_0>>;
|
||||
@@ -7769,6 +7892,7 @@ typedef decltype(kernel_mul_mm_id<DD<float4x4, 1, dequantize_f32>>) mat_mm_id_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<float4x4, 1, dequantize_f32>>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<half4x4, 1, dequantize_f16>>;
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<half4x4, 1, dequantize_bf16>>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_0, 2, dequantize_q4_0>>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_1, 2, dequantize_q4_1>>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_0, 2, dequantize_q5_0>>;
|
||||
@@ -7987,6 +8111,7 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
||||
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_bf16_impl<float>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
|
||||
Reference in New Issue
Block a user