Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-09 16:00:12 +00:00)
Metal: FA and FlashMLA (#310)
* Metal: WIP to update the Metal FA implementation; Dk = 192, Dv = 128 works, but not Dk = 576, Dv = 512
* Metal FA: go to float
* WIP
* Metal FA: MLA options now all work

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
@@ -225,6 +225,39 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
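The block above follows the existing MUL_MM_<SRC0>_<SRC1> naming convention: the first suffix is the (usually quantized) weight type of src0, the second the activation type of src1, so this hunk adds one matrix-matrix kernel per supported src0 type for F16 activations, alongside the existing F32 set. A minimal C sketch of that convention only; the helper below is purely illustrative and not part of ggml:

```c
#include <stdio.h>

// Illustrative only: compose the enum-style name used above from the two
// operand type names, e.g. ("IQ4_K", "F16") -> ..._MUL_MM_IQ4_K_F16.
static void mul_mm_kernel_name(char *out, size_t n,
                               const char *src0_type, const char *src1_type) {
    snprintf(out, n, "GGML_METAL_KERNEL_TYPE_MUL_MM_%s_%s", src0_type, src1_type);
}

int main(void) {
    char name[96];
    mul_mm_kernel_name(name, sizeof name, "IQ4_K", "F16");
    printf("%s\n", name); // GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16
    return 0;
}
```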
@@ -276,9 +309,33 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
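The flash-attention kernels are specialized per head size: the H<N> variants cover equal K/V head sizes, while the new HK192_HV128 and HK576_HV512 variants cover the asymmetric pairs the FlashMLA path needs (the Dk = 192/Dv = 128 and Dk = 576/Dv = 512 shapes from the commit message), each in an F16 and a Q8_0 KV-cache flavor plus a VEC (single-row) form. A small C sketch of the resulting head-size inventory, written as an assumption from the enum names above:

```c
#include <stdbool.h>
#include <stdio.h>

// Sketch of the head-size combinations that get a dedicated kernel above.
// Equal K/V head sizes use the H<N> variants; the two asymmetric pairs are
// the HK192_HV128 and HK576_HV512 (FlashMLA) variants.
static bool fa_head_sizes_supported(int dk, int dv) {
    if (dk == dv) {
        return dk == 64 || dk == 80 || dk == 96 || dk == 112 || dk == 128 || dk == 256;
    }
    return (dk == 192 && dv == 128) || (dk == 576 && dv == 512);
}

int main(void) {
    printf("%d\n", fa_head_sizes_supported(576, 512)); // 1: FlashMLA shape
    printf("%d\n", fa_head_sizes_supported(192, 192)); // 0: no H192 kernel
    return 0;
}
```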
@@ -290,7 +347,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CONCAT,
GGML_METAL_KERNEL_TYPE_CONCAT_F32,
GGML_METAL_KERNEL_TYPE_CONCAT_F16,
GGML_METAL_KERNEL_TYPE_SQR,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -793,6 +851,39 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, mul_mm_iq4_k_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32, mul_mm_iq5_k_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, mul_mm_iq6_k_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16, mul_mm_f32_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16, mul_mm_f16_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16, mul_mm_bf16_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16, mul_mm_q4_0_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16, mul_mm_q4_1_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16, mul_mm_q5_0_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16, mul_mm_q5_1_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16, mul_mm_q6_0_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16, mul_mm_q8_0_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16, mul_mm_q2_K_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16, mul_mm_q3_K_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16, mul_mm_q4_K_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16, mul_mm_q5_K_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16, mul_mm_q6_K_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16, mul_mm_iq2_xxs_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16, mul_mm_iq2_xs_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16, mul_mm_iq3_xxs_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16, mul_mm_iq3_s_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16, mul_mm_iq2_s_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16, mul_mm_iq1_s_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16, mul_mm_iq1_m_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16, mul_mm_iq1_bn_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, mul_mm_iq2_bn_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, mul_mm_iq4_nl_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, mul_mm_iq4_xs_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, mul_mm_iq4_ks_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16, mul_mm_iq4_kss_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16, mul_mm_iq2_k_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16, mul_mm_iq2_ks_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16, mul_mm_iq3_k_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16, mul_mm_iq4_k_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16, mul_mm_iq5_k_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16, mul_mm_iq6_k_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, ctx->support_simdgroup_mm);
@@ -844,9 +935,33 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80, flash_attn_ext_vec_q8_0_h80, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112, flash_attn_ext_vec_q8_0_h112, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
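Each GGML_METAL_ADD_KERNEL line registers one Metal compute function and stores its pipeline in ctx->kernels, gated on a device-capability flag; note that the vector flash-attention kernels are now gated on ctx->support_simdgroup_mm, whereas the removed VEC_F16_H128 registration used ctx->support_simdgroup_reduction. A plain-C model of that capability-gated registration table, purely as an assumption about the pattern (the real code creates MTLComputePipelineState objects in Objective-C):

```c
#include <stdbool.h>
#include <stdio.h>
#include <stddef.h>

// Plain-C stand-in for a kernel registration slot; names are illustrative.
typedef struct {
    const char *fn_name;   // Metal function name, e.g. "kernel_flash_attn_ext_vec_f16_hk576_hv512"
    bool        supported; // capability gate, e.g. simdgroup matrix support
    void       *pipeline;  // stands in for the compiled pipeline object
} kernel_slot;

static bool add_kernel(kernel_slot *slot, const char *fn_name, bool supported) {
    slot->fn_name   = fn_name;
    slot->supported = supported;
    slot->pipeline  = NULL; // would hold the compiled pipeline when supported
    return supported;
}

int main(void) {
    kernel_slot slot;
    bool simdgroup_mm = true; // assumed device capability
    add_kernel(&slot, "kernel_flash_attn_ext_vec_f16_hk576_hv512", simdgroup_mm);
    printf("%s supported: %d\n", slot.fn_name, slot.supported);
    return 0;
}
```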
@@ -858,7 +973,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0, cpy_f32_q6_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT_F32, concat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT_F16, concat_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
}
@@ -1001,17 +1117,24 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_FLASH_ATTN_EXT:
if (op->src[1]->type != GGML_TYPE_F16) {
if (!ctx->support_simdgroup_mm) {
return false; // TODO: over-restricted for vec-kernels
}
if (op->src[1]->type != op->src[2]->type ||
(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_Q8_0)) {
return false;
}
if (op->src[2]->type != GGML_TYPE_F16) {
return false;
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
return (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) ||
(op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512);
}
if (op->src[0]->ne[0] == 256) {
return false;
}
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
return (op->src[1]->ne[0] == 64 || op->src[1]->ne[0] == 80 ||
op->src[1]->ne[0] == 96 || op->src[1]->ne[0] == 112 ||
op->src[1]->ne[0] == 128 || op->src[1]->ne[0] == 256);
case GGML_OP_MUL_MAT:
return ctx->support_simdgroup_reduction &&
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
!(op->src[0]->type >= GGML_TYPE_Q4_0_R8 && op->src[0]->type <= GGML_TYPE_Q8_K_R8);
case GGML_OP_MUL_MAT_ID:
return ctx->support_simdgroup_reduction &&
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
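Putting the new checks in this hunk together, the GGML_OP_FLASH_ATTN_EXT acceptance rule now reads roughly as follows: the device must support simdgroup matrix multiplies, K and V must share a type (F16 or Q8_0), and the head sizes must either be equal and in the fixed list or match one of the two FlashMLA pairs. A standalone C restatement with stand-in types (a sketch, not the ggml API):

```c
#include <stdbool.h>
#include <stdio.h>

// Stand-in type tags for the sketch; the real code uses ggml_type values.
enum { TYPE_F16, TYPE_Q8_0, TYPE_OTHER };

static bool metal_supports_flash_attn_ext(bool support_simdgroup_mm,
                                          int k_type, int v_type,
                                          int k_head, int v_head) {
    if (!support_simdgroup_mm) {
        return false; // the source marks this TODO: over-restricted for vec-kernels
    }
    if (k_type != v_type || (k_type != TYPE_F16 && k_type != TYPE_Q8_0)) {
        return false;
    }
    if (k_head != v_head) {
        // the two FlashMLA shapes
        return (k_head == 192 && v_head == 128) || (k_head == 576 && v_head == 512);
    }
    return k_head == 64 || k_head == 80 || k_head == 96 ||
           k_head == 112 || k_head == 128 || k_head == 256;
}

int main(void) {
    printf("%d\n", metal_supports_flash_attn_ext(true, TYPE_Q8_0, TYPE_Q8_0, 576, 512)); // 1
    printf("%d\n", metal_supports_flash_attn_ext(true, TYPE_F16, TYPE_Q8_0, 128, 128));  // 0: mixed K/V types
    return 0;
}
```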
@@ -1157,7 +1280,18 @@ static void ggml_metal_encode_node(
switch (dst->op) {
case GGML_OP_CONCAT:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
GGML_ASSERT(src0->type == src1->type && src0->type == dst->type);

id<MTLComputePipelineState> pipeline;
if (dst->type == GGML_TYPE_F32) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT_F32].pipeline;
}
else if (dst->type == GGML_TYPE_F16) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT_F16].pipeline;
}
else {
GGML_ABORT("CONCAT not implemented for this type");
}

const int32_t dim = ((int32_t *) dst->op_params)[0];
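CONCAT thus goes from a single generic kernel to one kernel per element type, chosen from dst->type after asserting that src0, src1, and dst all share that type. A minimal C sketch of the same select-by-type pattern (the kernel names below are assumed for illustration):

```c
#include <stdio.h>

// Minimal sketch of the CONCAT pipeline choice: one kernel per element type,
// abort/NULL on anything else (mirrors the F32/F16 branches above).
enum dtype { DT_F32, DT_F16, DT_OTHER };

static const char * concat_kernel_for(enum dtype t) {
    switch (t) {
        case DT_F32: return "kernel_concat_f32"; // assumed function name
        case DT_F16: return "kernel_concat_f16"; // assumed function name
        default:     return NULL;                // the real code calls GGML_ABORT here
    }
}

int main(void) {
    printf("%s\n", concat_kernel_for(DT_F16));
    return 0;
}
```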
@@ -1945,7 +2079,7 @@ static void ggml_metal_encode_node(
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
!ggml_is_transposed(src0) &&
!ggml_is_transposed(src1) &&
src1t == GGML_TYPE_F32 &&
(src1t == GGML_TYPE_F32 || src1t == GGML_TYPE_F16) &&
ne00 % 32 == 0 && ne00 >= 64 &&
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
@@ -1960,41 +2094,84 @@ static void ggml_metal_encode_node(

id<MTLComputePipelineState> pipeline = nil;

switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32].pipeline; break;
case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32].pipeline; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32].pipeline; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32].pipeline; break;
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32].pipeline; break;
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32].pipeline; break;
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32].pipeline; break;
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32].pipeline; break;
case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break;
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32].pipeline; break;
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32].pipeline; break;
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32].pipeline; break;
case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32].pipeline; break;
case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32].pipeline; break;
case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32].pipeline; break;
default: GGML_ABORT("MUL MAT-MAT not implemented");
if (src1->type == GGML_TYPE_F32) {
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32].pipeline; break;
case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32].pipeline; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32].pipeline; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32].pipeline; break;
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32].pipeline; break;
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32].pipeline; break;
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32].pipeline; break;
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32].pipeline; break;
case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break;
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32].pipeline; break;
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32].pipeline; break;
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32].pipeline; break;
case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32].pipeline; break;
case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32].pipeline; break;
case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32].pipeline; break;
default: GGML_ABORT("MUL MAT-MAT not implemented");
}
}
else if (src1->type == GGML_TYPE_F16) {
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16].pipeline; break;
case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16].pipeline; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16].pipeline; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16].pipeline; break;
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16].pipeline; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16].pipeline; break;
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16].pipeline; break;
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16].pipeline; break;
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16].pipeline; break;
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16].pipeline; break;
case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16].pipeline; break;
case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16].pipeline; break;
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16].pipeline; break;
case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16].pipeline; break;
case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16].pipeline; break;
case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16].pipeline; break;
case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16].pipeline; break;
case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16].pipeline; break;
default: GGML_ABORT("MUL MAT-MAT not implemented");
}
}
else {
GGML_ABORT("Unsupported src1 type for MUL-MAT");
}

[encoder setComputePipelineState:pipeline];
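The single switch over src0->type is replaced by an outer branch on src1->type (F32 or F16) with the same inner switch over src0->type, i.e. the matrix-matrix kernel is now picked from a (src1, src0) grid. A conceptual C sketch of that two-level lookup, using a small stand-in table rather than the real ctx->kernels array:

```c
#include <stdio.h>

// Conceptual view of the new selection: kernel = table[src1_type][src0_type].
// Stand-in enums; the real code indexes ctx->kernels with the enum values above.
enum src1_t { SRC1_F32, SRC1_F16, SRC1_COUNT };
enum src0_t { SRC0_F16, SRC0_Q8_0, SRC0_IQ4_K, SRC0_COUNT };

static const char * const mul_mm_table[SRC1_COUNT][SRC0_COUNT] = {
    [SRC1_F32] = { "mul_mm_f16_f32", "mul_mm_q8_0_f32", "mul_mm_iq4_k_f32" },
    [SRC1_F16] = { "mul_mm_f16_f16", "mul_mm_q8_0_f16", "mul_mm_iq4_k_f16" },
};

int main(void) {
    // e.g. quantized weights multiplied by an F16 activation matrix
    printf("%s\n", mul_mm_table[SRC1_F16][SRC0_IQ4_K]); // mul_mm_iq4_k_f16
    return 0;
}
```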
@@ -3204,8 +3381,9 @@ static void ggml_metal_encode_node(
GGML_ASSERT(ne11 % 32 == 0);

GGML_ASSERT(src0->type == GGML_TYPE_F32);

GGML_ASSERT(ggml_are_same_shape (src1, src2));
GGML_ASSERT(src1->type == src2->type);
GGML_ASSERT(ne11 == ne21);
GGML_ASSERT(ne12 == ne22);

struct ggml_tensor * src3 = node->src[3];
@@ -3250,70 +3428,189 @@ static void ggml_metal_encode_node(

bool use_vec_kernel = false;

if (ne01 >= 4 || (ne00%128 != 0)) {
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192 && ne00 != 576)) {
switch (src1->type) {
case GGML_TYPE_F16:
{
if (ne00 == 192 && ne20 == 128) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
}
else if (ne00 == 576 && ne20 == 512) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
default:
{
GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00);
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
}
} break;
case GGML_TYPE_Q8_0:
{
if (ne00 == 192 && ne20 == 128) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
}
else if (ne00 == 576 && ne20 == 512) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
default:
{
GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00);
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
}
} break;
default:
{
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
{
GGML_METAL_LOG_ERROR("unsupported type: %s\n", ggml_type_name(src1->type));
GGML_METAL_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} else {
use_vec_kernel = true;

switch (ne00) {
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
switch (src1->type) {
case GGML_TYPE_F16:
{
if (ne00 == 192 && ne20 == 128) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline;
}
else if (ne00 == 576 && ne20 == 512) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
default:
{
GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00);
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
}
} break;
case GGML_TYPE_Q8_0:
{
if (ne00 == 192 && ne20 == 128) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline;
}
else if (ne00 == 576 && ne20 == 512) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
default:
{
GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00);
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
}
} break;
default:
{
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
{
GGML_METAL_LOG_ERROR("unsupported type: %s\n", ggml_type_name(src1->type));
GGML_METAL_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}

}

typedef struct {
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
uint64_t nb31;
int32_t ne1;
int32_t ne2;
float scale;
float max_bias;
float m0;
float m1;
uint16_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;

ggml_metal_kargs_flash_attn_ext args = {
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne11 =*/ ne11,
/*.ne_12_2 =*/ ne12,
/*.ne_12_3 =*/ ne13,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.nb31 =*/ nb31,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
/*.m1 =*/ m1,
/*.n_head_log2 =*/ n_head_log2,
/*.logit_softcap =*/ softcap,
};

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
if (id_src3) {
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setBytes:&ne01 length:sizeof(int64_t) atIndex:5];
[encoder setBytes:&ne02 length:sizeof(int64_t) atIndex:6];
[encoder setBytes:&ne03 length:sizeof(int64_t) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&ne11 length:sizeof(int64_t) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(int64_t) atIndex:12];
[encoder setBytes:&ne13 length:sizeof(int64_t) atIndex:13];
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
[encoder setBytes:&ne1 length:sizeof(int64_t) atIndex:21];
[encoder setBytes:&ne2 length:sizeof(int64_t) atIndex:22];
[encoder setBytes:&scale length:sizeof(float) atIndex:23];
[encoder setBytes:&max_bias length:sizeof(float) atIndex:24];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
[encoder setBytes:&softcap length:sizeof(softcap) atIndex:27];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28];
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];

if (!use_vec_kernel) {
// half8x8 kernel
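Two things change in this hunk: the kernel is now chosen per KV-cache type (F16 or Q8_0) and per (Dk, Dv) pair, and the long list of individual setBytes calls (indices 5-28) is replaced by a single ggml_metal_kargs_flash_attn_ext argument block bound at buffer index 0, with Q, K, V, mask, and dst moved to indices 1-5 (when there is no mask, the Q buffer is rebound at index 4 as a placeholder). The host-side struct below is copied from the diff; the small main is just a layout-inspection sketch, useful because the C-side packing must match what the Metal kernels read from the constant buffer:

```c
#include <stdint.h>
#include <stdio.h>
#include <stddef.h>

// Argument block from the diff; its layout has to agree with the struct the
// Metal flash-attention kernels read at buffer index 0.
typedef struct {
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne11;
    int32_t  ne_12_2; // assume K and V are same shape
    int32_t  ne_12_3;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb21;
    uint64_t nb22;
    uint64_t nb23;
    uint64_t nb31;
    int32_t  ne1;
    int32_t  ne2;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint16_t n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;

int main(void) {
    // The mixed 4/8-byte fields introduce padding; printing it makes a layout
    // mismatch with the .metal side easy to spot.
    printf("sizeof(kargs)  = %zu\n", sizeof(ggml_metal_kargs_flash_attn_ext));
    printf("offsetof(nb01) = %zu\n", offsetof(ggml_metal_kargs_flash_attn_ext, nb01));
    return 0;
}
```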
@@ -3324,10 +3621,19 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

// 2*(2*ncpsg + nqptg)*(nsg)
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
//
// 16*32*(nsg)
// the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;

while (true) {
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > ctx->device.maxThreadgroupMemoryLength) {
break;
}
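For the half8x8 kernel the threadgroup memory is now computed through FATTN_SMEM(nsg): the Q tile, the softmax/mask scratch, and a per-simdgroup staging area for 16x32 dequantized KV elements, all stored as half (sizeof(float)/2 = 2 bytes) and padded to 16 bytes; nsgmax is doubled until the estimate exceeds the device limit. A small C evaluation of the macro, assuming nqptg = 8 and ncpsg = 32 (the values this path normally uses) and an illustrative 32 KB threadgroup-memory limit:

```c
#include <stdio.h>
#include <stddef.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

// FATTN_SMEM from the diff (half8x8 kernel path), with nqptg = 8, ncpsg = 32 assumed.
static size_t fattn_smem(int ne00, int nsg) {
    const int nqptg = 8;  // queries per threadgroup
    const int ncpsg = 32; // cache values processed per simdgroup
    return GGML_PAD((size_t)(nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg)*(sizeof(float)/2), 16);
}

int main(void) {
    const size_t max_tg_mem = 32 * 1024; // illustrative limit; queried from the device in practice
    for (int ne00 = 128; ne00 <= 576; ne00 += 448) {   // Dk = 128 and Dk = 576
        for (int nsg = 2; nsg <= 8; nsg *= 2) {
            size_t smem = fattn_smem(ne00, nsg);
            printf("ne00 = %3d nsg = %d -> smem = %5zu bytes (%s)\n",
                   ne00, nsg, smem, smem <= max_tg_mem ? "fits" : "too large");
        }
    }
    return 0;
}
```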
@@ -3338,14 +3644,14 @@ static void ggml_metal_encode_node(
// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;

const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
const size_t smem = FATTN_SMEM(nsg);

//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);

[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

[encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];

} else {
// half1x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
@@ -3355,8 +3661,27 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
//
// ne00*(nsg)
// each simdgroup has a full f16 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > ctx->device.maxThreadgroupMemoryLength) {
break;
}
nsgmax *= 2;
}
nsgmax /= 2;

// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));

int64_t nsg = 1;
while (nsg <= nsgt) {
@@ -3364,13 +3689,14 @@ static void ggml_metal_encode_node(
}
nsg /= 2;

const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
const size_t smem = FATTN_SMEM(nsg);

//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

[encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];

}
} break;
case GGML_OP_DUP:
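The vec (half1x4) path now uses its own FATTN_SMEM: the Q vector is padded to a multiple of 128 head elements, 4*ncpsg halves per simdgroup hold softmax/mask scratch, and ne20 (the V head size) halves per simdgroup accumulate the output; nsg is then the largest power of two not exceeding nsgt, which is itself capped by nsgmax, ne11/ncpsg, and the pipeline thread limit. A C sketch of that selection for the FlashMLA decode shape, with assumed device limits (nqptg = 1, ncpsg = 32 as in this path):

```c
#include <stdio.h>
#include <stddef.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

// FATTN_SMEM from the vec-kernel path in the diff, with nqptg = 1, ncpsg = 32.
static size_t fattn_vec_smem(int ne00, int ne20, int nsg) {
    const int nqptg = 1, ncpsg = 32;
    return GGML_PAD((size_t)(nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*nsg) + ne20*nsg)*(sizeof(float)/2), 16);
}

int main(void) {
    // FlashMLA decode shape: Dk = 576, Dv = 512; the limits below are assumptions for illustration.
    const int ne00 = 576, ne20 = 512;
    const long ne11 = 4096;              // KV cache length (example)
    const long max_threads = 1024;       // pipeline.maxTotalThreadsPerThreadgroup (assumed)
    const size_t max_tg_mem = 32 * 1024; // maxThreadgroupMemoryLength (assumed)

    long nsgmax = 2;
    while (fattn_vec_smem(ne00, ne20, (int)nsgmax) <= max_tg_mem) nsgmax *= 2;
    nsgmax /= 2; // largest power of two whose estimate still fits

    long nsgt = MAX(2, MIN(nsgmax, MIN(ne11/32, max_threads/32)));
    long nsg = 1;
    while (nsg <= nsgt) nsg *= 2;
    nsg /= 2;    // largest power of two <= nsgt

    printf("nsgmax = %ld, nsg = %ld, smem = %zu bytes\n",
           nsgmax, nsg, fattn_vec_smem(ne00, ne20, (int)nsg));
    return 0;
}
```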
File diff suppressed because it is too large.