mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-01 03:41:53 +00:00
WIP
This commit is contained in:
@@ -527,19 +527,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|||||||
if (src1->type == GGML_TYPE_F16) {
|
if (src1->type == GGML_TYPE_F16) {
|
||||||
auto to_fp16 = ggml_get_to_fp16_cuda(src0->type);
|
auto to_fp16 = ggml_get_to_fp16_cuda(src0->type);
|
||||||
if (to_fp16) {
|
if (to_fp16) {
|
||||||
to_fp16(src0->data, (half *)src1->data, ggml_nrows(src0), src0->ne[1], main_stream);
|
to_fp16(src0->data, (half *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (src1->type == GGML_TYPE_F32) {
|
else if (src1->type == GGML_TYPE_F32) {
|
||||||
auto to_fp32 = ggml_get_to_fp32_cuda(src0->type);
|
auto to_fp32 = ggml_get_to_fp32_cuda(src0->type);
|
||||||
if (to_fp32) {
|
if (to_fp32) {
|
||||||
to_fp32(src0->data, (float *)src1->data, ggml_nrows(src0), src0->ne[1], main_stream);
|
to_fp32(src0->data, (float *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (src1->type == GGML_TYPE_BF16) {
|
else if (src1->type == GGML_TYPE_BF16) {
|
||||||
auto to_bf16 = ggml_get_to_bf16_cuda(src0->type);
|
auto to_bf16 = ggml_get_to_bf16_cuda(src0->type);
|
||||||
if (to_bf16) {
|
if (to_bf16) {
|
||||||
to_bf16(src0->data, (nv_bfloat16 *)src1->data, ggml_nrows(src0), src0->ne[1], main_stream);
|
to_bf16(src0->data, (nv_bfloat16 *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -579,9 +579,21 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
|||||||
return (void*) cpy_f32_f16<cpy_1_f16_f16>;
|
return (void*) cpy_f32_f16<cpy_1_f16_f16>;
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||||
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
|
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
|
||||||
} else {
|
} else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) {
|
||||||
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
if (src1->type == GGML_TYPE_F16) {
|
||||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
auto to_fp16 = ggml_get_to_fp16_cuda(src0->type);
|
||||||
GGML_ABORT("fatal error");
|
if (to_fp16) return (void*)to_fp16;
|
||||||
|
}
|
||||||
|
else if (src1->type == GGML_TYPE_F32) {
|
||||||
|
auto to_fp32 = ggml_get_to_fp32_cuda(src0->type);
|
||||||
|
if (to_fp32) return (void*)to_fp32;
|
||||||
|
}
|
||||||
|
else if (src1->type == GGML_TYPE_BF16) {
|
||||||
|
auto to_bf16 = ggml_get_to_bf16_cuda(src0->type);
|
||||||
|
if (to_bf16) return (void*)to_bf16;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
||||||
|
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user