From 015bcb88d4d2db12d2e98dc7db8d6e3a7c25ee1d Mon Sep 17 00:00:00 2001
From: Harsh Dave
Date: Wed, 30 Mar 2022 07:16:24 -0500
Subject: [PATCH] Fixed ztrsm computational failure

- Fixed memory access for edge cases so that all loads stay within
  memory boundaries.
- Corrected ztrsm utility APIs for dcomplex multiplication and division.

AMD-Internal: [CPUPL-2093]
Change-Id: Ib2c65e7921f6391b530cd20d6ea6b50f24bd705e
---
 kernels/zen/3/bli_trsm_small.c | 771 ++++++++++++++++++++++++---------
 1 file changed, 567 insertions(+), 204 deletions(-)

diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c
index 0fa8f66d5..32b7647a5 100644
--- a/kernels/zen/3/bli_trsm_small.c
+++ b/kernels/zen/3/bli_trsm_small.c
@@ -4,7 +4,7 @@

    An object-based framework for developing high-performance BLAS-like
    libraries.

-   Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved.
+   Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
@@ -3891,33 +3891,20 @@ err_t bli_trsm_small
 */

 #define DCOMPLEX_INV(a, b) {\
-    a.real = b.real;\
-    a.imag = (b.imag * -1.0);\
-    /*Compute denominator eliminating imaginary component*/\
-    double dnm = (b.real * b.real);\
-    /*multiply two times with -1 for correct result as
-     * dcomplex number with positive imaginary part will
-     * invert the sign if not multiplied twice with -1*/\
-    dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\
-    /*Compute the final result by dividing real and imag part by dnm*/\
-    a.real /= dnm;\
-    a.imag /= dnm;\
+    /* dcomplex inva = {1.0, 0.0};*/\
+    a.real = 1.0;\
+    a.imag = 0.0;\
+    bli_zinvscals(b, a);\
 }

 #define DCOMPLEX_MUL(a, b, c) {\
-    double real = a.real * b.real;\
-    real += ((a.imag * b.imag) * -1.0);\
-    double imag = (a.real * b.imag);\
-    imag += (a.imag * b.real);\
-    c.real = real;\
-    c.imag = imag;\
+    c.real = b.real;\
+    c.imag = b.imag;\
+    bli_zscals(a,c);\
 }

 #define DCOMPLEX_DIV(a, b){\
-    double dnm = b.real * b.real;\
-    dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\
-    a.real /= dnm;\
-    a.imag /= dnm;\
+    bli_zinvscals(b,a); \
 }

@@ -3946,11 +3933,8 @@ err_t bli_trsm_small
 #define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\
     if(!is_unitdiag)\
     {\
-        a.real = b.real;\
-        a.imag = (b.imag * -1.0);\
-        DCOMPLEX_MUL(c, a, c)\
-        DCOMPLEX_DIV(c, b)\
-    }\
+        bli_zinvscals(b, c);\
+    }\
 }
 #endif

@@ -4299,6 +4283,213 @@ BLIS_INLINE err_t ztrsm_AuXB_ref
     _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\
 }

+
+#define BLIS_ZTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) {\
+    double *tptr = (double *)b01;\
+    if(conjtransa) {\
+        ymm16 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\
+        for(k = 0; k< k_iter; k++) \
+        { \
+            ymm0 = _mm256_loadu_pd((double const *)(a10)); \
+            xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\
+            ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \
+            ymm0 = _mm256_mul_pd(ymm0, ymm16);\
+            ymm1 = _mm256_mul_pd(ymm1, ymm16);\
+            \
+            ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \
+            ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \
+            \
+            ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\
+            ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\
+            ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\
+            ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\
+            \
+            ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \
+            ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \
+            \
+            ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\
+            ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\
+            ymm6 = 
_mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ + ymm14 = _mm256_permute_pd(ymm14, 0x5);\ + ymm15 = _mm256_permute_pd(ymm15, 0x5);\ + \ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm11 = _mm256_addsub_pd(ymm11, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm7);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm14);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm15);\ +} + + +#define BLIS_ZTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double * )b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = 
_mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ +\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm7);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ +} + + /** * Performs GEMM operation. 
* Two elements of column in ymm0 @@ -31943,75 +32134,160 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB if(m_rem == 3) { dim_t p_lda = 4; - if(transa) - { - for(dim_t x = 0; x < i; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm10 = _mm256_loadu_pd((double const *) - (a10 + 2)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - ymm11 = _mm256_loadu_pd((double const *) - (a10 + 2 + cs_a)); + if(transa) + { + dim_t x = 0; + for(x = 0; (x+3) < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *) + (a10 + 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *) + (a10 + 2 + cs_a)); - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a)); - ymm10 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a + 2)); + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a + 2)); + ymm1 = _mm256_set_pd(1, 1, 1, 1); - ymm1 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a + 2)); + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm1,0x31); - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup + 2), - ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda + 2), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2 + 2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3 + 2), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2 + 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3 + 2), ymm9); - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + for(; (x+2) < i; x += 3) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + xmm4 = _mm_loadu_pd((double const *) + (a10 + 2)); + ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + xmm4 = _mm_loadu_pd((double const *) + (a10 + 2 + cs_a)); + ymm11 = _mm256_insertf128_pd(ymm11, xmm4, 0); - } - else - { - for(dim_t x=0;x 0; j -= d_nr) { @@ -33429,37 +33791,38 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB } else if(m_remainder == 1) { - dim_t p_lda = 
2; // packed leading dimension - if(transa) - { - for(dim_t x = 0; x < m-m_remainder; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } - - } - else - { - for(dim_t x=0;x 0; j -= d_nr) {
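
A note on the dcomplex utility rewrite above: the hand-rolled denominator
arithmetic in DCOMPLEX_INV/DCOMPLEX_MUL/DCOMPLEX_DIV is replaced by the BLIS
level-0 scalar macros, where bli_zscals(a, c) computes c := a * c and
bli_zinvscals(b, a) computes a := a / b. The stand-alone sketch below shows
those semantics with the textbook conjugate/|b|^2 formulas; the struct and
function names are illustrative assumptions, not the BLIS implementation
(which may guard the division against overflow differently).

    /* Reference semantics sketch for the scalar helpers; illustrative only. */
    #include <stdio.h>

    typedef struct { double real, imag; } dcomplex_t;

    /* c := a * c  --  complex multiply in place (bli_zscals-like). */
    static void zscals_ref(dcomplex_t a, dcomplex_t *c)
    {
        double re = a.real * c->real - a.imag * c->imag;
        double im = a.real * c->imag + a.imag * c->real;
        c->real = re;
        c->imag = im;
    }

    /* a := a / b  --  multiply by conj(b), divide by |b|^2
     * (bli_zinvscals-like). Note the denominator is re^2 + im^2. */
    static void zinvscals_ref(dcomplex_t b, dcomplex_t *a)
    {
        double dnm = b.real * b.real + b.imag * b.imag;
        double re  = (a->real * b.real + a->imag * b.imag) / dnm;
        double im  = (a->imag * b.real - a->real * b.imag) / dnm;
        a->real = re;
        a->imag = im;
    }

    int main(void)
    {
        /* DCOMPLEX_INV pattern: start from 1 + 0i, inverse-scale by b. */
        dcomplex_t b = { 3.0, -4.0 };
        dcomplex_t a = { 1.0,  0.0 };
        zinvscals_ref(b, &a);               /* a = 1/b = (0.12, 0.16) */
        printf("1/b = (%g, %g)\n", a.real, a.imag);

        dcomplex_t c = { 2.0, 1.0 };
        zscals_ref(b, &c);                  /* c = b*c = (10, -5) */
        printf("b*c = (%g, %g)\n", c.real, c.imag);
        return 0;
    }

Delegating to one shared helper also makes ZTRSM_DIAG_ELE_EVAL_OPS a single
inverse-scale (c := c / b) instead of a conjugate, a multiply, and a divide
composed by hand.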
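
The new BLIS_ZTRSM_SMALL_GEMM_3mx*n macros accumulate complex products with
the usual AVX idiom: broadcast the real and imaginary parts of the B element
separately, fmadd each into its own accumulator, swap the re/im pairs of the
imaginary accumulator with _mm256_permute_pd(..., 0x5), and combine with
_mm256_addsub_pd. A minimal stand-alone sketch of that idiom follows (plain
multiplies instead of fmadd since the accumulators start at zero; names are
illustrative, not kernel code):

    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
    {
        /* Two complex numbers packed as [re0, im0, re1, im1]. */
        double a[4] = { 1.0, 2.0, 3.0, 4.0 };   /* 1+2i, 3+4i */
        double br = 5.0, bi = 6.0;              /* b = 5+6i    */

        __m256d va  = _mm256_loadu_pd(a);
        __m256d vbr = _mm256_broadcast_sd(&br); /* [br, br, br, br] */
        __m256d vbi = _mm256_broadcast_sd(&bi); /* [bi, bi, bi, bi] */

        __m256d accr = _mm256_mul_pd(va, vbr);  /* [ar*br, ai*br, ...] */
        __m256d acci = _mm256_mul_pd(va, vbi);  /* [ar*bi, ai*bi, ...] */

        /* Swap re/im within each 128-bit pair: [ai*bi, ar*bi, ...] */
        acci = _mm256_permute_pd(acci, 0x5);

        /* Even lanes subtract, odd lanes add:
         * [ar*br - ai*bi, ai*br + ar*bi, ...] = a * b */
        __m256d prod = _mm256_addsub_pd(accr, acci);

        double out[4];
        _mm256_storeu_pd(out, prod);
        printf("(%g,%g) (%g,%g)\n", out[0], out[1], out[2], out[3]);
        /* Expected: (1+2i)(5+6i) = -7+16i ; (3+4i)(5+6i) = -9+38i */
        return 0;
    }

This is why each macro keeps paired accumulators (e.g. ymm8/ymm4): one carries
the real-broadcast products, the other the imaginary-broadcast products, and
the permute/addsub epilogue fuses them into the complex result.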
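
On the memory-boundary fix itself: for m_rem == 3 a column holds exactly
three dcomplex values (six doubles), so the old pair of 256-bit loads
(_mm256_loadu_pd at offset 0 and offset 2) read one dcomplex past the edge.
The patch replaces the tail load with a 128-bit _mm_loadu_pd inserted into a
ymm register. A sketch of the pattern, assuming a buffer of exactly six
doubles (in the kernel the insert targets an existing ymm whose upper lane is
simply never consumed; here it is zeroed for clarity):

    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
    {
        /* Exactly three dcomplex values: six doubles, nothing after them. */
        double col[6] = { 1, 2, 3, 4, 5, 6 };

        /* Unsafe pattern (what the old code effectively did):
         *   __m256d hi = _mm256_loadu_pd(col + 4);
         * would read col[4..7], i.e. two doubles past the allocation. */

        __m256d lo   = _mm256_loadu_pd(col);    /* dcomplex 0,1: col[0..3] */
        __m128d tail = _mm_loadu_pd(col + 4);   /* dcomplex 2:   col[4..5] */
        __m256d hi   = _mm256_insertf128_pd(_mm256_setzero_pd(), tail, 0);

        double out[8] = { 0 };
        _mm256_storeu_pd(out, lo);
        _mm256_storeu_pd(out + 4, hi);
        for (int i = 0; i < 6; i++)
            printf("%g ", out[i]);              /* 1 2 3 4 5 6 */
        printf("\n");
        return 0;
    }

The same split-load pattern appears throughout the reworked packing loops,
which now advance in full p_lda-wide steps only while a whole block fits and
fall back to the narrower loads for the remainder columns.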