diff --git a/kernels/zen/3/bli_gemm_tiny.c b/kernels/zen/3/bli_gemm_tiny.c index 4edbcd6c4..32e42490b 100644 --- a/kernels/zen/3/bli_gemm_tiny.c +++ b/kernels/zen/3/bli_gemm_tiny.c @@ -517,35 +517,55 @@ err_t bli_dgemm_tiny // Query the architecture ID arch_t id = bli_arch_query_id(); - if(m <= 24 && n <= 24 && k <= 20) - { // Pick the kernel based on the architecture ID - switch (id) - { - case BLIS_ARCH_ZEN5: - case BLIS_ARCH_ZEN4: - case BLIS_ARCH_ZEN3: - case BLIS_ARCH_ZEN2: - case BLIS_ARCH_ZEN: - return bli_dgemm_tiny_6x8_kernel - ( - 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), - 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), - transa, - transb, - m, - n, - k, - alpha, - a, rs_a0, cs_a0, - b, rs_b0, cs_b0, - beta, - c, rs_c0, cs_c0 - ); - break; - default: - return BLIS_FAILURE; - } + switch (id) + { + case BLIS_ARCH_ZEN5: + if(m<24 && ((n<=24 && k<=20) || + (n<=50 && ((m<=4 && k<=50) || (m!=8 && m!=9 && m!=16 && k<=10))))) + { + return bli_dgemm_tiny_6x8_kernel + ( + 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), + 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), + transa, + transb, + m, + n, + k, + alpha, + a, rs_a0, cs_a0, + b, rs_b0, cs_b0, + beta, + c, rs_c0, cs_c0 + ); + } + break; + case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN3: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN: + if(m <= 24 && n <= 24 && k <= 20) + { + return bli_dgemm_tiny_6x8_kernel + ( + 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), + 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), + transa, + transb, + m, + n, + k, + alpha, + a, rs_a0, cs_a0, + b, rs_b0, cs_b0, + beta, + c, rs_c0, cs_c0 + ); + } + break; + default: + return BLIS_FAILURE; } if(FALSE == bli_thread_get_is_parallel())