Seems to be working with coopmat

This commit is contained in:
Iwan Kawrakow
2025-07-14 13:33:18 +03:00
parent 495139a3e3
commit ae12c8b616
2 changed files with 200 additions and 27 deletions

View File

@@ -503,6 +503,17 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
// ============================== ik_llama.cpp pipelines begin ========================================
vk_pipeline pipeline_fused_rms_norm_f32;
vk_pipeline pipeline_fused_mul_gelu[2];
vk_pipeline pipeline_fused_mul_silu[2];
vk_pipeline pipeline_fused_mul_relu[2];
vk_pipeline pipeline_multi_add_f32;
// ============================== ik_llama.cpp pipelines end ========================================
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
@@ -689,6 +700,13 @@ struct vk_op_glu_push_constants {
uint32_t mode; // 0: default, 1: swapped, 2: split
};
struct vk_op_multiadd_push_constants {
uint32_t ne;
uint32_t ne0, ne1;
uint32_t nb0, nb01;
uint32_t nadd;
};
struct vk_op_unary_push_constants {
uint32_t ne;
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -2963,6 +2981,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
// ================================ ik_llama.cpp pipelines begin =========================================
//
ggml_vk_create_pipeline(device, device->pipeline_fused_rms_norm_f32, "fused_rms_norm_f32", fused_rms_norm_f32_len, fused_rms_norm_f32_data,
"main", 3, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_fused_mul_silu[0], "fused_mul_silu_f32", fused_mul_silu_f32_len, fused_mul_silu_f32_data,
"main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_fused_mul_silu[1], "fused_mul_silu_f16", fused_mul_silu_f16_len, fused_mul_silu_f16_data,
"main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_fused_mul_gelu[0], "fused_mul_gelu_f32", fused_mul_gelu_f32_len, fused_mul_gelu_f32_data,
"main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_fused_mul_gelu[1], "fused_mul_gelu_f16", fused_mul_gelu_f16_len, fused_mul_gelu_f16_data,
"main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_fused_mul_relu[0], "fused_mul_relu_f32", fused_mul_relu_f32_len, fused_mul_relu_f32_data,
"main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_fused_mul_relu[1], "fused_mul_relu_f16", fused_mul_relu_f16_len, fused_mul_relu_f16_data,
"main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_multi_add_f32, "multi_add_f32", multi_add_f32_len, multi_add_f32_data,
"main", 2, sizeof(vk_op_multiadd_push_constants), {512, 1, 1}, {}, 1);
//
// ================================ ik_llama.cpp pipelines end =========================================
for (auto &c : compiles) {
c.wait();
}
@@ -3720,12 +3762,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
GGML_LOG_INFO("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
GGML_LOG_WARN("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
}
}
@@ -4134,6 +4176,11 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
}
}
if (!(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16))) {
printf("Oops: %s, %s, prec = %d, ctx->device->fp16 = %d, ctx->device->coopmat_support = %d, ctx->device->coopmat_acc_f16_support = %d\n",
ggml_type_name(src0_type), ggml_type_name(src1_type), prec, ctx->device->fp16, ctx->device->coopmat_support, ctx->device->coopmat_acc_f16_support);
}
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
switch (src0_type) {
@@ -6807,6 +6854,38 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
// }
// }
// return nullptr;
case GGML_OP_FUSED_RMS_NORM:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_fused_rms_norm_f32;
}
return nullptr;
case GGML_OP_FUSED_MUL_UNARY:
if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
(src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
(src0->type != dst->type) || (src1->type != dst->type)) {
return nullptr;
} else {
ggml_unary_op unary_op = (ggml_unary_op)dst->op_params[0];
switch (unary_op) {
case GGML_UNARY_OP_SILU:
return ctx->device->pipeline_fused_mul_silu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_GELU:
return ctx->device->pipeline_fused_mul_gelu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_RELU:
return ctx->device->pipeline_fused_mul_relu[dst->type == GGML_TYPE_F16];
default:
break;
}
return nullptr;
}
case GGML_OP_MULTI_ADD:
if (src0->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F32 ||
dst->ne[2] == 1 || dst->ne[3] == 1) {
return ctx->device->pipeline_multi_add_f32;
}
return nullptr;
default:
return nullptr;
}
@@ -6836,6 +6915,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
//case GGML_OP_CONV_2D_DW:
case GGML_OP_IM2COL:
//case GGML_OP_SET_ROWS:
case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_MULTI_ADD:
return true;
default:
return false;
@@ -7120,6 +7201,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t OW = dst->ne[0];
elements = { N * OC * OH * OW, 1, 1};
} break;
case GGML_OP_FUSED_RMS_NORM:
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
@@ -7137,33 +7222,35 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_UNARY:
case GGML_OP_FUSED_MUL_UNARY:
case GGML_OP_MULTI_ADD:
//case GGML_OP_GLU:
//case GGML_OP_CONV_2D_DW:
// {
// uint32_t ne = ggml_nelements(dst);
// if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
// // Convert from number of logical elements to 2- or 4-byte units.
// ne /= ggml_blck_size(src0->type);
// if ((ggml_type_size(src0->type) % 4) == 0) {
// ne *= ggml_type_size(src0->type) / 4;
// } else {
// ne *= ggml_type_size(src0->type) / 2;
// }
// }
// // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
// // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
// // So divide by block size here before splitting into 512x512 groups.
// if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
// ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
// }
// if (ne > 262144) {
// elements = { 512, 512, CEIL_DIV(ne, 262144) };
// } else if (ne > 512) {
// elements = { 512, CEIL_DIV(ne, 512), 1 };
// } else {
// elements = { ne, 1, 1 };
// }
// } break;
{
uint32_t ne = ggml_nelements(dst);
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
// Convert from number of logical elements to 2- or 4-byte units.
ne /= ggml_blck_size(src0->type);
if ((ggml_type_size(src0->type) % 4) == 0) {
ne *= ggml_type_size(src0->type) / 4;
} else {
ne *= ggml_type_size(src0->type) / 2;
}
}
// copy_to_quant has block size of 32, and each thread does QUANT_K elements.
// Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
// So divide by block size here before splitting into 512x512 groups.
if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
}
if (ne > 262144) {
elements = { 512, 512, CEIL_DIV(ne, 262144) };
} else if (ne > 512) {
elements = { 512, CEIL_DIV(ne, 512), 1 };
} else {
elements = { ne, 1, 1 };
}
} break;
//case GGML_OP_SET_ROWS:
// {
// uint32_t ne = ggml_nelements(src0);
@@ -7753,6 +7840,37 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_fused_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
GGML_ASSERT(src1->ne[0] == src0->ne[0]);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_FUSED_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], 1u, 1u, 1u, (uint32_t)src1->nb[0] / src1_type_size, 0u, 0u, 0u,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
op_params[0], 0.0f, 0,
}, dryrun);
}
static void ggml_vk_fused_mul_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_FUSED_MUL_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
uint32_t nadd = (uint32_t)dst->op_params[0];
ggml_vk_op_f32<vk_op_multiadd_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MULTI_ADD,
{ (uint32_t)ggml_nelements(dst), (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)(dst->nb[1]/sizeof(float)), (uint32_t)(src0->nb[1]/sizeof(float)), nadd }, dryrun);
}
#if 0
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
@@ -9049,6 +9167,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_FUSED_MUL_UNARY:
case GGML_OP_MULTI_ADD:
//case GGML_OP_L2_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
@@ -9116,6 +9237,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_FUSED_MUL_UNARY:
case GGML_OP_MULTI_ADD:
//case GGML_OP_L2_NORM:
case GGML_OP_UNARY:
//case GGML_OP_GLU:
@@ -9251,6 +9375,15 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_RMS_NORM_BACK:
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_FUSED_RMS_NORM:
ggml_vk_fused_rms_norm(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_FUSED_MUL_UNARY:
ggml_vk_fused_mul_unary(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_MULTI_ADD:
ggml_vk_multi_add(ctx, compute_ctx, src0, node, dryrun);
break;
//case GGML_OP_L2_NORM:
// ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -9449,6 +9582,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_FUSED_MUL_UNARY:
case GGML_OP_MULTI_ADD:
//case GGML_OP_L2_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
@@ -10216,6 +10352,21 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
return false;
}
break;
case GGML_OP_FUSED_MUL_UNARY:
switch ((ggml_unary_op)op->op_params[0]) {
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
return ggml_is_contiguous(op->src[0]) && ggml_are_same_shape(op->src[0], op->src[1]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
(op->src[0]->type == op->type) && (op->src[1]->type == op->type);
default:
return false;
}
break;
case GGML_OP_MULTI_ADD:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
//case GGML_OP_GLU:
// switch (ggml_get_glu_op(op)) {
// case GGML_GLU_OP_GEGLU:
@@ -10465,6 +10616,7 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_RMS_NORM:
case GGML_OP_FUSED_RMS_NORM:
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
@@ -11029,6 +11181,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
const float eps = ((float *) tensor->op_params)[0];
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
} else if (tensor->op == GGML_OP_FUSED_RMS_NORM) {
tensor_clone = ggml_fused_rms_norm(ggml_ctx, src_clone[0], src_clone[1], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_FUSED_MUL_UNARY) {
tensor_clone = ggml_fused_mul_unary(ggml_ctx, src_clone[0], src_clone[1], (ggml_unary_op)tensor->op_params[0]);
} else if (tensor->op == GGML_OP_MULTI_ADD) {
tensor_clone = ggml_multi_add(ggml_ctx, src_clone[0], tensor->op_params[0]);
} else if (tensor->op == GGML_OP_SILU_BACK) {
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_L2_NORM) {

View File

@@ -655,6 +655,21 @@ void process_shaders() {
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
// ============================== ik_llama.cpp
//
string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("fused_mul_gelu_f16", "fused_mul_gelu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("fused_mul_gelu_f32", "fused_mul_gelu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("fused_mul_silu_f16", "fused_mul_silu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("fused_mul_silu_f32", "fused_mul_silu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("fused_mul_relu_f16", "fused_mul_relu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("fused_mul_relu_f32", "fused_mul_relu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
//
// ============================== end ik_llama.cpp
for (auto &c : compiles) {
c.wait();
}