diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 3b9ec14b6..83a7dcf7c 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -1123,16 +1123,51 @@ void dtrsm_blis_impl * In case of multithread when [m+n]<320 single thread implementation * is doing better than small multithread and native multithread */ bool is_parallel = bli_thread_get_is_parallel(); - if ((!is_parallel && ((dim_a < 1500) && (size_b < 5e6)) ) || - (is_parallel && (m0+n0)<200)) + switch(id) { - switch(id) - { - case BLIS_ARCH_ZEN5: - case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN5: #if defined(BLIS_KERNELS_ZEN4) + // In native code path, input buffers are packed. + // Let's say packed buffers improve the speed of + // computation by a factor of 'S' and it takes 'X' + // units of time to pack buffers. If a computation + // without packed buffer would have take 'T' time, + // then it would take 'T/S + X' time with packed buffers + // where S > 1. + // Time complexity of TRSM is (M^2 * N) in left variants + // and (N^2 * M) in right variants. + // Therefore time taken by Small path for left variant will be + // (M^2 * N) + // and time taken by Native path for left variant will be + // (M^2 * N) / S + X + // We should take small code path when + // (M^2 * N) < (M^2 * N) / S + X + // solving this gives us + // (M^2 * N) < (X * S) / ( S - 1) + // Here RHS is constant, which can be found using empirical data + // (X * S) / ( S - 1) is found to be around 6.3e6 on Turin + // In order the reduce the possiblity of overflow, taking log on + // both sides gives us + // 2log(m) + log(n) < 6.8 for left variant + if ( ( blis_side == BLIS_LEFT ) && + ( (log10(n0) + (2*log10(m0)) ) < 6.8 ) ) + { + ker_ft = bli_trsm_small_AVX512; + } + else if ( ( blis_side == BLIS_RIGHT ) && + ( (log10(m0) + (2*log10(n0)) ) < 6.8 ) ) + { + ker_ft = bli_trsm_small_AVX512; + } + break; +#endif // BLIS_KERNELS_ZEN4 + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + if ((!is_parallel && ((dim_a < 1500) && (size_b < 5e6)) ) || + (is_parallel && (m0+n0)<200)) + { /* For sizes where m and n < 50,avx2 kernels are performing better, - except for sizes where n is multiple of 8.*/ + except for sizes where n is multiple of 8.*/ if (((n0 % 8 == 0) && (n0 < 50)) || ((m0 > 50) && (n0 > 50))) { ker_ft = bli_trsm_small_AVX512; @@ -1141,37 +1176,61 @@ void dtrsm_blis_impl { ker_ft = bli_trsm_small; } - break; + } + break; #endif // BLIS_KERNELS_ZEN4 - case BLIS_ARCH_ZEN: - case BLIS_ARCH_ZEN2: - case BLIS_ARCH_ZEN3: - default: + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + default: + if ((!is_parallel && ((dim_a < 1500) && (size_b < 5e6)) ) || + (is_parallel && (m0+n0)<200)) + { ker_ft = bli_trsm_small; - break; - } + } + break; } #ifdef BLIS_ENABLE_OPENMP - if( (ker_ft == NULL) && (is_parallel) && - ((dim_a < 2500) && (size_b < 5e6)) ) + switch(id) { - switch(id) - { - case BLIS_ARCH_ZEN5: - case BLIS_ARCH_ZEN4: + case BLIS_ARCH_ZEN5: #if defined(BLIS_KERNELS_ZEN4) - ker_ft = bli_trsm_small_mt_AVX512; - break; + if( (is_parallel) && n0 > 10 && m0 > 10 ) + { + if ( ( blis_side == BLIS_LEFT ) && + ( (log10(n0) + (2*log10(m0)) ) < 6.8 ) ) + { + ker_ft = bli_trsm_small_mt_AVX512; + } + else if ( ( blis_side == BLIS_RIGHT ) && + ( (log10(m0) + (2*log10(n0)) ) < 6.8 ) ) + { + ker_ft = bli_trsm_small_mt_AVX512; + } + } + break; #endif// BLIS_KERNELS_ZEN4 - case BLIS_ARCH_ZEN: - case BLIS_ARCH_ZEN2: - case BLIS_ARCH_ZEN3: - default: + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + if( (ker_ft == NULL) && (is_parallel) && + ((dim_a < 2500) && (size_b < 5e6)) ) + { + ker_ft = bli_trsm_small_mt_AVX512; + } + break; +#endif// BLIS_KERNELS_ZEN4 + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + default: + if( (ker_ft == NULL) && (is_parallel) && + ((dim_a < 2500) && (size_b < 5e6)) ) + { ker_ft = bli_trsm_small_mt; - break; + } + break; } - } #endif// BLIS_ENABLE_OPENMP if(ker_ft)