bitnet(scale in a separate tensor): mul -> scale on Metal

Do the mul -> scale replacement on the fly in the Metal backend.
This recovers the PP performace and cuts the TG performance
degradation in half.
This commit is contained in:
Iwan Kawrakow
2024-06-19 18:23:57 +02:00
parent d08ff0df43
commit 7f968d51b4

View File

@@ -1077,7 +1077,30 @@ static enum ggml_status ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil;
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) {
float scale;
memcpy(&scale, src1->data, sizeof(float));
//printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
int64_t n = ggml_nelements(dst);
if (n % 4 == 0) {
n /= 4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
break;
}
else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));
// src1 is a row