Mirror of https://github.com/amd/blis.git (synced 2026-05-04 14:31:12 +00:00)

Commit details:
- NOTE: This is a merge commit of 'master' of git://github.com/amd/blis into 'amd-master' of flame/blis.
- Fixed a bug in the downstream value of BLIS_NUM_ARCHS, which was inadvertently not incremented when the Zen2 subconfiguration was added.
- In bli_gemm_front(), added a missing conditional constraint around the call to bli_gemm_small() that ensures that the computation precision of C matches the storage precision of C.
- In bli_syrk_front(), reorganized and relocated the notrans/trans logic that existed around the call to bli_syrk_small() into bli_syrk_small() to minimize the calling code footprint and to bring that code into stylistic harmony with similar code in bli_gemm_front() and bli_trsm_front(). Also, replaced direct accessing of obj_t fields with the proper accessor static functions (e.g. 'a->dim[0]' becomes 'bli_obj_length( a )').
- Added #ifdef BLIS_ENABLE_SMALL_MATRIX guards around the prototypes for bli_gemm_small(), bli_syrk_small(), and bli_trsm_small(). This is strictly speaking unnecessary, but it serves as a useful visual cue to those who may be reading the files.
- Removed cpp macro-protected small matrix debugging code from bli_trsm_front.c.
- Added a GCC_OT_9_1_0 variable to build/config.mk.in to facilitate a gcc version check for availability of -march=znver2, and added appropriate support to the configure script.
- Cleanups to compiler flags common to recent AMD microarchitectures in config/zen/amd_config.mk, including: removal of -march=znver1 et al. from CKVECFLAGS (since the -march flag is added within make_defs.mk); setting CRVECFLAGS similarly to CKVECFLAGS.
- Cleanups to config/zen/bli_cntx_init_zen.c.
- Cleanups and added comments in config/zen/make_defs.mk.
- Cleanups to config/zen2/make_defs.mk, including making use of the newly-added GCC_OT_9_1_0 and the existing GCC_OT_6_1_0 to choose the correct set of compiler flags based on the version of gcc being used.
- Reverted downstream changes to test/test_gemm.c.
- Various whitespace/comment changes.
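For illustration, the accessor change noted above swaps direct obj_t field reads for the corresponding accessor functions; a minimal sketch only (the variable name here is hypothetical, the two forms are those quoted in the note):

    dim_t m = a->dim[0];            /* before: direct field access */
    dim_t m = bli_obj_length( a );  /* after:  accessor function   */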
/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name of The University of Texas at Austin nor the names
      of its contributors may be used to endorse or promote products
      derived from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "immintrin.h"
|
|
#include "xmmintrin.h"
|
|
#include "blis.h"
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX
|
|
|
|
#define MR 32
|
|
#define D_MR (MR >> 1)
|
|
#define NR 3
|
|
|
|
#define BLIS_ENABLE_PREFETCH
|
|
#define F_SCRATCH_DIM (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)
|
|
static float A_pack[F_SCRATCH_DIM] __attribute__((aligned(64)));
|
|
static float C_pack[F_SCRATCH_DIM] __attribute__((aligned(64)));
|
|
#define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 )
|
|
#define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2)
|
|
#define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2)
|
|
#define D_SCRATCH_DIM (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)
|
|
static double D_A_pack[D_SCRATCH_DIM] __attribute__((aligned(64)));
|
|
static double D_C_pack[D_SCRATCH_DIM] __attribute__((aligned(64)));
|
|
#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called.
|
|
#define AT_MR 4 // The kernel dimension of the A transpose SYRK kernel.(AT_MR * NR).
|
|
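// Blocking overview (a descriptive note): the single-precision kernels process
// MR = 32 rows of C per outer iteration and NR = 3 columns per inner iteration,
// using the static A_pack/C_pack scratch buffers to hold a packed panel of A
// and the intermediate C results. The double-precision kernels use
// D_MR = MR/2 = 16 rows with the D_A_pack/D_C_pack buffers.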
static err_t bli_ssyrk_small
     (
       obj_t*  alpha,
       obj_t*  a,
       obj_t*  b,
       obj_t*  beta,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl
     );

static err_t bli_dsyrk_small
     (
       obj_t*  alpha,
       obj_t*  a,
       obj_t*  b,
       obj_t*  beta,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl
     );

static err_t bli_ssyrk_small_atbn
     (
       obj_t*  alpha,
       obj_t*  a,
       obj_t*  b,
       obj_t*  beta,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl
     );

static err_t bli_dsyrk_small_atbn
     (
       obj_t*  alpha,
       obj_t*  a,
       obj_t*  b,
       obj_t*  beta,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl
     );

/*
 * The bli_syrk_small() function uses custom MRxNR kernels to perform the
 * computation. The custom kernels are used only when M * N < 240 * 240.
 */
err_t bli_syrk_small
     (
       obj_t*  alpha,
       obj_t*  a,
       obj_t*  b,
       obj_t*  beta,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl
     )
{
	// FGVZ: This code was originally in bli_syrk_front(). However, it really
	// fits more naturally here within the bli_syrk_small() function. This
	// becomes a bit more obvious now that the code is here, as it contains
	// cpp macros such as BLIS_SMALL_MATRIX_A_THRES_M_SYRK, which are specific
	// to this implementation.
	if ( bli_obj_has_trans( a ) )
	{
		// Continue with the small implementation.
		;
	}
	else if ( ( bli_obj_length( a ) <= BLIS_SMALL_MATRIX_A_THRES_M_SYRK &&
	            bli_obj_width( a )  <  BLIS_SMALL_MATRIX_A_THRES_N_SYRK ) ||
	          ( bli_obj_length( a ) <  BLIS_SMALL_MATRIX_A_THRES_M_SYRK &&
	            bli_obj_width( a )  <= BLIS_SMALL_MATRIX_A_THRES_N_SYRK ) )
	{
		// Continue with the small implementation.
		;
	}
	else
	{
		// Reject the problem and return to the large code path.
		return BLIS_FAILURE;
	}

#ifdef BLIS_ENABLE_MULTITHREADING
	return BLIS_NOT_YET_IMPLEMENTED;
#endif

	// If alpha is zero, the operation reduces to scaling C by beta; defer
	// that case back to the framework code path.
	if (bli_obj_equals(alpha, &BLIS_ZERO))
	{
		return BLIS_NOT_YET_IMPLEMENTED;
	}

	// Only column-major storage is supported; reject row-major matrices.
	if ((bli_obj_row_stride( a ) != 1) ||
	    (bli_obj_row_stride( b ) != 1) ||
	    (bli_obj_row_stride( c ) != 1))
	{
		return BLIS_INVALID_ROW_STRIDE;
	}

	// Extract the datatype bits from the info field of C.
	num_t dt = ((*c).info & (0x7 << 0));

	if (bli_obj_has_trans( a ))
	{
		if (bli_obj_has_notrans( b ))
		{
			if (dt == BLIS_FLOAT)
			{
				return bli_ssyrk_small_atbn(alpha, a, b, beta, c, cntx, cntl);
			}
			else if (dt == BLIS_DOUBLE)
			{
				return bli_dsyrk_small_atbn(alpha, a, b, beta, c, cntx, cntl);
			}
		}

		return BLIS_NOT_YET_IMPLEMENTED;
	}

	if (dt == BLIS_DOUBLE)
	{
		return bli_dsyrk_small(alpha, a, b, beta, c, cntx, cntl);
	}

	if (dt == BLIS_FLOAT)
	{
		return bli_ssyrk_small(alpha, a, b, beta, c, cntx, cntl);
	}

	return BLIS_NOT_YET_IMPLEMENTED;
}

static err_t bli_ssyrk_small
|
|
(
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
obj_t* beta,
|
|
obj_t* c,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
|
|
int M = bli_obj_length( c ); // number of rows of Matrix C
|
|
int N = bli_obj_width( c ); // number of columns of Matrix C
|
|
int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
|
|
int L = M * N;
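	// Note: the small-matrix path below is taken only when M*N is under the
	// square threshold, or when both M and K are under the rectangular
	// thresholds; otherwise the problem is handed back to the framework
	// (see the BLIS_NONCONFORMAL_DIMENSIONS return at the end).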
|
|
|
|
if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES))
|
|
|| ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
|
|
{
|
|
|
|
int lda = bli_obj_col_stride(a); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
|
|
int ldb = bli_obj_col_stride(b); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
|
|
int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
|
|
int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C
|
|
int row_idx, col_idx, k;
|
|
int rs_matC = bli_obj_row_stride( c );
|
|
int rsc = 1;
|
|
float *A = a->buffer; // pointer to elements of Matrix A
|
|
float *B = b->buffer; // pointer to elements of Matrix B
|
|
float *C = C_pack; // pointer to elements of Matrix C
|
|
float *matCbuf = c->buffer;
|
|
|
|
float *tA = A, *tB = B, *tC = C;//, *tA_pack;
|
|
		float *tA_packed; // temporary pointer to hold the packed A memory pointer
|
|
int row_idx_packed; //packed A memory row index
|
|
int lda_packed; //lda of packed A
|
|
int col_idx_start; //starting index after A matrix is packed.
|
|
dim_t tb_inc_row = 1; // row stride of matrix B
|
|
dim_t tb_inc_col = ldb; // column stride of matrix B
|
|
__m256 ymm4, ymm5, ymm6, ymm7;
|
|
__m256 ymm8, ymm9, ymm10, ymm11;
|
|
__m256 ymm12, ymm13, ymm14, ymm15;
|
|
__m256 ymm0, ymm1, ymm2, ymm3;
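		// Register usage in the MR x NR (32 x 3) micro-tile: ymm4..ymm15 hold
		// the twelve 8-float accumulators (four per column of C), ymm0..ymm2
		// hold broadcast elements of B, and ymm3 holds the current 8-float
		// slice of A.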
|
|
|
|
		int n_remainder; // remainder when N is not a multiple of 3 (N % 3)
		int m_remainder; // remainder when M is not a multiple of 32 (M % 32)
|
|
|
|
float *alpha_cast, *beta_cast; // alpha, beta multiples
|
|
alpha_cast = (alpha->buffer);
|
|
beta_cast = (beta->buffer);
|
|
int required_packing_A = 1;
|
|
|
|
// when N is equal to 1 call GEMV instead of SYRK
|
|
if (N == 1)
|
|
{
|
|
bli_gemv
|
|
(
|
|
alpha,
|
|
a,
|
|
b,
|
|
beta,
|
|
c
|
|
);
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
//update the pointer math if matrix B needs to be transposed.
|
|
if (bli_obj_has_trans( b ))
|
|
{
|
|
tb_inc_col = 1; //switch row and column strides
|
|
tb_inc_row = ldb;
|
|
}
|
|
|
|
if ((N <= 3) || ((MR * K) > F_SCRATCH_DIM))
|
|
{
|
|
required_packing_A = 0;
|
|
}
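		// Packing is skipped when there is at most one NR-wide block of
		// columns (no reuse of the packed A panel) or when an MR x K panel
		// of A would not fit in the static A_pack scratch buffer.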
|
|
/*
|
|
* The computation loop runs for MRxN columns of C matrix, thus
|
|
* accessing the MRxK A matrix data and KxNR B matrix data.
|
|
* The computation is organized as inner loops of dimension MRxNR.
|
|
*/
|
|
// Process MR rows of C matrix at a time.
|
|
for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR)
|
|
{
|
|
|
|
col_idx_start = 0;
|
|
tA_packed = A;
|
|
row_idx_packed = row_idx;
|
|
lda_packed = lda;
|
|
|
|
// This is the part of the pack and compute optimization.
|
|
// During the first column iteration, we store the accessed A matrix into
|
|
			// contiguous static memory. This helps to keep the A matrix in cache and
			// avoids TLB misses.
|
|
if (required_packing_A)
|
|
{
|
|
col_idx = 0;
|
|
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
tA_packed = A_pack;
|
|
|
|
#if 0//def BLIS_ENABLE_PREFETCH
|
|
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
|
|
#endif
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_ps();
|
|
ymm5 = _mm256_setzero_ps();
|
|
ymm6 = _mm256_setzero_ps();
|
|
ymm7 = _mm256_setzero_ps();
|
|
ymm8 = _mm256_setzero_ps();
|
|
ymm9 = _mm256_setzero_ps();
|
|
ymm10 = _mm256_setzero_ps();
|
|
ymm11 = _mm256_setzero_ps();
|
|
ymm12 = _mm256_setzero_ps();
|
|
ymm13 = _mm256_setzero_ps();
|
|
ymm14 = _mm256_setzero_ps();
|
|
ymm15 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
// This loop is processing MR x K
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
_mm256_storeu_ps(tA_packed, ymm3); // the packing of matrix A
|
|
// ymm4 += ymm0 * ymm3;
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
// ymm8 += ymm1 * ymm3;
|
|
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
|
|
// ymm12 += ymm2 * ymm3;
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 8);
|
|
_mm256_storeu_ps(tA_packed + 8, ymm3); // the packing of matrix A
|
|
// ymm5 += ymm0 * ymm3;
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
// ymm9 += ymm1 * ymm3;
|
|
ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
|
|
// ymm13 += ymm2 * ymm3;
|
|
ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 16);
|
|
_mm256_storeu_ps(tA_packed + 16, ymm3); // the packing of matrix A
|
|
// ymm6 += ymm0 * ymm3;
|
|
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
|
|
// ymm10 += ymm1 * ymm3;
|
|
ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
|
|
// ymm14 += ymm2 * ymm3;
|
|
ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 24);
|
|
_mm256_storeu_ps(tA_packed + 24, ymm3); // the packing of matrix A
|
|
// ymm7 += ymm0 * ymm3;
|
|
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
|
|
// ymm11 += ymm1 * ymm3;
|
|
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
|
|
// ymm15 += ymm2 * ymm3;
|
|
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
|
|
|
|
tA += lda;
|
|
tA_packed += MR;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_ps(ymm6, ymm0);
|
|
ymm7 = _mm256_mul_ps(ymm7, ymm0);
|
|
ymm8 = _mm256_mul_ps(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_ps(ymm9, ymm0);
|
|
ymm10 = _mm256_mul_ps(ymm10, ymm0);
|
|
ymm11 = _mm256_mul_ps(ymm11, ymm0);
|
|
ymm12 = _mm256_mul_ps(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_ps(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
|
ymm15 = _mm256_mul_ps(ymm15, ymm0);
|
|
|
|
// multiply C by beta and accumulate col 1.
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
|
ymm2 = _mm256_loadu_ps(tC + 24);
|
|
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_ps(tC, ymm4);
|
|
_mm256_storeu_ps(tC + 8, ymm5);
|
|
_mm256_storeu_ps(tC + 16, ymm6);
|
|
_mm256_storeu_ps(tC + 24, ymm7);
|
|
|
|
// multiply C by beta and accumulate, col 2.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
|
ymm2 = _mm256_loadu_ps(tC + 24);
|
|
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/
|
|
_mm256_storeu_ps(tC, ymm8);
|
|
_mm256_storeu_ps(tC + 8, ymm9);
|
|
_mm256_storeu_ps(tC + 16, ymm10);
|
|
_mm256_storeu_ps(tC + 24, ymm11);
|
|
|
|
// multiply C by beta and accumulate, col 3.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
|
ymm2 = _mm256_loadu_ps(tC + 24);
|
|
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
|
|
_mm256_storeu_ps(tC, ymm12);
|
|
_mm256_storeu_ps(tC + 8, ymm13);
|
|
_mm256_storeu_ps(tC + 16, ymm14);
|
|
_mm256_storeu_ps(tC + 24, ymm15);
|
|
|
|
			// modify the pointer arithmetic to use the packed A matrix.
|
|
col_idx_start = NR;
|
|
tA_packed = A_pack;
|
|
row_idx_packed = 0;
|
|
lda_packed = MR;
|
|
}
|
|
// Process NR columns of C matrix at a time.
|
|
for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = tA_packed + row_idx_packed;
|
|
|
|
#if 0//def BLIS_ENABLE_PREFETCH
|
|
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
|
|
#endif
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_ps();
|
|
ymm5 = _mm256_setzero_ps();
|
|
ymm6 = _mm256_setzero_ps();
|
|
ymm7 = _mm256_setzero_ps();
|
|
ymm8 = _mm256_setzero_ps();
|
|
ymm9 = _mm256_setzero_ps();
|
|
ymm10 = _mm256_setzero_ps();
|
|
ymm11 = _mm256_setzero_ps();
|
|
ymm12 = _mm256_setzero_ps();
|
|
ymm13 = _mm256_setzero_ps();
|
|
ymm14 = _mm256_setzero_ps();
|
|
ymm15 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
// This loop is processing MR x K
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
// ymm4 += ymm0 * ymm3;
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
// ymm8 += ymm1 * ymm3;
|
|
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
|
|
// ymm12 += ymm2 * ymm3;
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 8);
|
|
// ymm5 += ymm0 * ymm3;
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
// ymm9 += ymm1 * ymm3;
|
|
ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
|
|
// ymm13 += ymm2 * ymm3;
|
|
ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 16);
|
|
// ymm6 += ymm0 * ymm3;
|
|
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
|
|
// ymm10 += ymm1 * ymm3;
|
|
ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
|
|
// ymm14 += ymm2 * ymm3;
|
|
ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 24);
|
|
// ymm7 += ymm0 * ymm3;
|
|
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
|
|
// ymm11 += ymm1 * ymm3;
|
|
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
|
|
// ymm15 += ymm2 * ymm3;
|
|
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
|
|
|
|
tA += lda_packed;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_ps(ymm6, ymm0);
|
|
ymm7 = _mm256_mul_ps(ymm7, ymm0);
|
|
ymm8 = _mm256_mul_ps(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_ps(ymm9, ymm0);
|
|
ymm10 = _mm256_mul_ps(ymm10, ymm0);
|
|
ymm11 = _mm256_mul_ps(ymm11, ymm0);
|
|
ymm12 = _mm256_mul_ps(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_ps(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
|
ymm15 = _mm256_mul_ps(ymm15, ymm0);
|
|
|
|
// multiply C by beta and accumulate col 1.
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
|
ymm2 = _mm256_loadu_ps(tC + 24);
|
|
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_ps(tC, ymm4);
|
|
_mm256_storeu_ps(tC + 8, ymm5);
|
|
_mm256_storeu_ps(tC + 16, ymm6);
|
|
_mm256_storeu_ps(tC + 24, ymm7);
|
|
|
|
// multiply C by beta and accumulate, col 2.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
|
ymm2 = _mm256_loadu_ps(tC + 24);
|
|
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/
|
|
_mm256_storeu_ps(tC, ymm8);
|
|
_mm256_storeu_ps(tC + 8, ymm9);
|
|
_mm256_storeu_ps(tC + 16, ymm10);
|
|
_mm256_storeu_ps(tC + 24, ymm11);
|
|
|
|
// multiply C by beta and accumulate, col 3.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
|
ymm2 = _mm256_loadu_ps(tC + 24);
|
|
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
|
|
_mm256_storeu_ps(tC, ymm12);
|
|
_mm256_storeu_ps(tC + 8, ymm13);
|
|
_mm256_storeu_ps(tC + 16, ymm14);
|
|
_mm256_storeu_ps(tC + 24, ymm15);
|
|
|
|
}
|
|
n_remainder = N - col_idx;
|
|
|
|
// if the N is not multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 2)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm8 = _mm256_setzero_ps();
|
|
ymm9 = _mm256_setzero_ps();
|
|
ymm10 = _mm256_setzero_ps();
|
|
ymm11 = _mm256_setzero_ps();
|
|
ymm12 = _mm256_setzero_ps();
|
|
ymm13 = _mm256_setzero_ps();
|
|
ymm14 = _mm256_setzero_ps();
|
|
ymm15 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
|
|
ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 8);
|
|
ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
|
|
ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 16);
|
|
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
|
|
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 24);
|
|
ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);
|
|
ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
|
|
|
|
tA += lda;
|
|
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm8 = _mm256_mul_ps(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_ps(ymm9, ymm0);
|
|
ymm10 = _mm256_mul_ps(ymm10, ymm0);
|
|
ymm11 = _mm256_mul_ps(ymm11, ymm0);
|
|
ymm12 = _mm256_mul_ps(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_ps(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
|
ymm15 = _mm256_mul_ps(ymm15, ymm0);
|
|
|
|
// multiply C by beta and accumulate, col 1.
|
|
/*ymm2 = _mm256_loadu_ps(tC + 0);
|
|
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
|
ymm2 = _mm256_loadu_ps(tC + 24);
|
|
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/
|
|
_mm256_storeu_ps(tC + 0, ymm8);
|
|
_mm256_storeu_ps(tC + 8, ymm9);
|
|
_mm256_storeu_ps(tC + 16, ymm10);
|
|
_mm256_storeu_ps(tC + 24, ymm11);
|
|
|
|
// multiply C by beta and accumulate, col 2.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
|
ymm2 = _mm256_loadu_ps(tC + 24);
|
|
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
|
|
_mm256_storeu_ps(tC, ymm12);
|
|
_mm256_storeu_ps(tC + 8, ymm13);
|
|
_mm256_storeu_ps(tC + 16, ymm14);
|
|
_mm256_storeu_ps(tC + 24, ymm15);
|
|
|
|
col_idx += 2;
|
|
}
|
|
// if the N is not multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm12 = _mm256_setzero_ps();
|
|
ymm13 = _mm256_setzero_ps();
|
|
ymm14 = _mm256_setzero_ps();
|
|
ymm15 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 8);
|
|
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 16);
|
|
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 24);
|
|
ymm15 = _mm256_fmadd_ps(ymm0, ymm3, ymm15);
|
|
|
|
tA += lda;
|
|
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm12 = _mm256_mul_ps(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_ps(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
|
ymm15 = _mm256_mul_ps(ymm15, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_ps(tC + 0);
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
|
ymm2 = _mm256_loadu_ps(tC + 24);
|
|
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
|
|
|
|
_mm256_storeu_ps(tC + 0, ymm12);
|
|
_mm256_storeu_ps(tC + 8, ymm13);
|
|
_mm256_storeu_ps(tC + 16, ymm14);
|
|
_mm256_storeu_ps(tC + 24, ymm15);
|
|
}
|
|
}
|
|
|
|
m_remainder = M - row_idx;
|
|
|
|
if (m_remainder >= 24)
|
|
{
|
|
m_remainder -= 24;
|
|
|
|
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_ps();
|
|
ymm5 = _mm256_setzero_ps();
|
|
ymm6 = _mm256_setzero_ps();
|
|
ymm8 = _mm256_setzero_ps();
|
|
ymm9 = _mm256_setzero_ps();
|
|
ymm10 = _mm256_setzero_ps();
|
|
ymm12 = _mm256_setzero_ps();
|
|
ymm13 = _mm256_setzero_ps();
|
|
ymm14 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
// ymm4 += ymm0 * ymm3;
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
// ymm8 += ymm1 * ymm3;
|
|
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
|
|
// ymm12 += ymm2 * ymm3;
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 8);
|
|
// ymm5 += ymm0 * ymm3;
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
// ymm9 += ymm1 * ymm3;
|
|
ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
|
|
// ymm13 += ymm2 * ymm3;
|
|
ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 16);
|
|
// ymm6 += ymm0 * ymm3;
|
|
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
|
|
// ymm10 += ymm1 * ymm3;
|
|
ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
|
|
// ymm14 += ymm2 * ymm3;
|
|
ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_ps(ymm6, ymm0);
|
|
ymm8 = _mm256_mul_ps(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_ps(ymm9, ymm0);
|
|
ymm10 = _mm256_mul_ps(ymm10, ymm0);
|
|
ymm12 = _mm256_mul_ps(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_ps(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);*/
|
|
_mm256_storeu_ps(tC, ymm4);
|
|
_mm256_storeu_ps(tC + 8, ymm5);
|
|
_mm256_storeu_ps(tC + 16, ymm6);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);*/
|
|
_mm256_storeu_ps(tC, ymm8);
|
|
_mm256_storeu_ps(tC + 8, ymm9);
|
|
_mm256_storeu_ps(tC + 16, ymm10);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/
|
|
_mm256_storeu_ps(tC, ymm12);
|
|
_mm256_storeu_ps(tC + 8, ymm13);
|
|
_mm256_storeu_ps(tC + 16, ymm14);
|
|
|
|
}
|
|
n_remainder = N - col_idx;
|
|
// if the N is not multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 2)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm8 = _mm256_setzero_ps();
|
|
ymm9 = _mm256_setzero_ps();
|
|
ymm10 = _mm256_setzero_ps();
|
|
ymm12 = _mm256_setzero_ps();
|
|
ymm13 = _mm256_setzero_ps();
|
|
ymm14 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
|
|
ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 8);
|
|
ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
|
|
ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 16);
|
|
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
|
|
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
|
|
|
|
tA += lda;
|
|
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm8 = _mm256_mul_ps(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_ps(ymm9, ymm0);
|
|
ymm10 = _mm256_mul_ps(ymm10, ymm0);
|
|
ymm12 = _mm256_mul_ps(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_ps(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_ps(tC + 0);
|
|
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);*/
|
|
_mm256_storeu_ps(tC + 0, ymm8);
|
|
_mm256_storeu_ps(tC + 8, ymm9);
|
|
_mm256_storeu_ps(tC + 16, ymm10);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/
|
|
_mm256_storeu_ps(tC, ymm12);
|
|
_mm256_storeu_ps(tC + 8, ymm13);
|
|
_mm256_storeu_ps(tC + 16, ymm14);
|
|
|
|
col_idx += 2;
|
|
}
|
|
// if the N is not multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm12 = _mm256_setzero_ps();
|
|
ymm13 = _mm256_setzero_ps();
|
|
ymm14 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 8);
|
|
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 16);
|
|
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
|
|
|
|
tA += lda;
|
|
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm12 = _mm256_mul_ps(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_ps(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_ps(tC + 0);
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_ps(tC + 16);
|
|
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/
|
|
|
|
_mm256_storeu_ps(tC + 0, ymm12);
|
|
_mm256_storeu_ps(tC + 8, ymm13);
|
|
_mm256_storeu_ps(tC + 16, ymm14);
|
|
}
|
|
|
|
row_idx += 24;
|
|
}
|
|
|
|
if (m_remainder >= 16)
|
|
{
|
|
m_remainder -= 16;
|
|
|
|
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_ps();
|
|
ymm5 = _mm256_setzero_ps();
|
|
ymm6 = _mm256_setzero_ps();
|
|
ymm7 = _mm256_setzero_ps();
|
|
ymm8 = _mm256_setzero_ps();
|
|
ymm9 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
|
|
ymm8 = _mm256_fmadd_ps(ymm2, ymm3, ymm8);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 8);
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_ps(ymm6, ymm0);
|
|
ymm7 = _mm256_mul_ps(ymm7, ymm0);
|
|
ymm8 = _mm256_mul_ps(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_ps(ymm9, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_ps(tC, ymm4);
|
|
_mm256_storeu_ps(tC + 8, ymm5);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_ps(tC, ymm6);
|
|
_mm256_storeu_ps(tC + 8, ymm7);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);*/
|
|
_mm256_storeu_ps(tC, ymm8);
|
|
_mm256_storeu_ps(tC + 8, ymm9);
|
|
|
|
}
|
|
n_remainder = N - col_idx;
|
|
// if the N is not multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 2)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_ps();
|
|
ymm5 = _mm256_setzero_ps();
|
|
ymm6 = _mm256_setzero_ps();
|
|
ymm7 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 8);
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_ps(ymm6, ymm0);
|
|
ymm7 = _mm256_mul_ps(ymm7, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_ps(tC, ymm4);
|
|
_mm256_storeu_ps(tC + 8, ymm5);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_ps(tC, ymm6);
|
|
_mm256_storeu_ps(tC + 8, ymm7);
|
|
|
|
col_idx += 2;
|
|
|
|
}
|
|
// if the N is not multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
ymm4 = _mm256_setzero_ps();
|
|
ymm5 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 8);
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_ps(tC + 8);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_ps(tC, ymm4);
|
|
_mm256_storeu_ps(tC + 8, ymm5);
|
|
|
|
}
|
|
|
|
row_idx += 16;
|
|
}
|
|
|
|
if (m_remainder >= 8)
|
|
{
|
|
m_remainder -= 8;
|
|
|
|
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_ps();
|
|
ymm5 = _mm256_setzero_ps();
|
|
ymm6 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
|
|
ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_ps(ymm6, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/
|
|
_mm256_storeu_ps(tC, ymm4);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_ps(tC, ymm5);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);*/
|
|
_mm256_storeu_ps(tC, ymm6);
|
|
}
|
|
n_remainder = N - col_idx;
|
|
// if the N is not multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 2)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
ymm4 = _mm256_setzero_ps();
|
|
ymm5 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/
|
|
_mm256_storeu_ps(tC, ymm4);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_ps(tC, ymm5);
|
|
|
|
col_idx += 2;
|
|
|
|
}
|
|
// if the N is not multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
ymm4 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_ps(tC);
|
|
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/
|
|
_mm256_storeu_ps(tC, ymm4);
|
|
|
|
}
|
|
|
|
row_idx += 8;
|
|
}
|
|
// M is not a multiple of 32.
|
|
// The handling of edge case where the remainder
|
|
// dimension is less than 8. The padding takes place
|
|
// to handle this case.
|
|
if ((m_remainder) && (lda > 7))
|
|
{
|
|
float f_temp[8];
|
|
|
|
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm5 = _mm256_setzero_ps();
|
|
ymm7 = _mm256_setzero_ps();
|
|
ymm9 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < (K - 1); ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tA[i];
|
|
}
|
|
ymm3 = _mm256_loadu_ps(f_temp);
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
|
|
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
ymm7 = _mm256_mul_ps(ymm7, ymm0);
|
|
ymm9 = _mm256_mul_ps(ymm9, ymm0);
|
|
|
|
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_ps(f_temp);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_ps(f_temp, ymm5);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
|
|
tC += ldc;
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_ps(f_temp);
|
|
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_ps(f_temp, ymm7);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
|
|
tC += ldc;
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_ps(f_temp);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);*/
|
|
_mm256_storeu_ps(f_temp, ymm9);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
}
|
|
n_remainder = N - col_idx;
|
|
// if the N is not multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 2)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
ymm5 = _mm256_setzero_ps();
|
|
ymm7 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < (K - 1); ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
|
|
|
|
tA += lda;
|
|
}
|
|
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tA[i];
|
|
}
|
|
ymm3 = _mm256_loadu_ps(f_temp);
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
|
|
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
ymm7 = _mm256_mul_ps(ymm7, ymm0);
|
|
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_ps(f_temp);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_ps(f_temp, ymm5);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
|
|
tC += ldc;
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_ps(f_temp);
|
|
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_ps(f_temp, ymm7);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
}
|
|
// if the N is not multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
ymm5 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; k < (K - 1); ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
|
|
tA += lda;
|
|
}
|
|
|
|
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tA[i];
|
|
}
|
|
ymm3 = _mm256_loadu_ps(f_temp);
|
|
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
|
|
|
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
|
|
// multiply C by beta and accumulate.
|
|
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
|
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_ps(f_temp);
|
|
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_ps(f_temp, ymm5);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
}
|
|
m_remainder = 0;
|
|
}
|
|
|
|
if (m_remainder)
|
|
{
|
|
float result;
|
|
for (; row_idx < M; row_idx += 1)
|
|
{
|
|
for (col_idx = 0; col_idx < N; col_idx += 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
result = 0;
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
result += (*tA) * (*tB);
|
|
tA += lda;
|
|
tB += tb_inc_row;
|
|
}
|
|
|
|
result *= (*alpha_cast);
|
|
(*tC) = /*(*tC) * (*beta_cast) + */result;
|
|
}
|
|
}
|
|
}
|
|
|
|
		// copy/compute syrk values back to C using SIMD
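		// Note: syrk updates only the triangle of C selected by its uplo; the
		// write-back below therefore walks the lower or upper triangle of the
		// packed C_pack buffer, using full 8-wide vector stores where possible
		// and scalar copies for the ragged edges. When beta == 0 the results
		// are copied as-is; otherwise beta*C is folded in via fmadd (vector)
		// or bli_sssxpbys (scalar).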
|
|
if ( bli_seq0( *beta_cast ) )
|
|
{//just copy in case of beta = 0
|
|
dim_t _i, _j, k, _l;
|
|
if(bli_obj_is_lower(c)) // c is lower
|
|
{
|
|
//first column
|
|
_j = 0;
|
|
k = M >> 3;
|
|
_i = 0;
|
|
for ( _l = 0; _l < k; _l++ )
|
|
{
|
|
ymm0 = _mm256_loadu_ps((C + _i*rsc));
|
|
_mm256_storeu_ps((matCbuf + _i*rs_matC), ymm0);
|
|
_i += 8;
|
|
}
|
|
while (_i < M )
|
|
{
|
|
bli_sscopys( *(C + _i*rsc + _j*ldc),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
_i++;
|
|
}
|
|
_j++;
|
|
while ( _j < N ) //next column
|
|
{
|
|
//k = (_j + (8 - (_j & 7)));
|
|
_l = _j & 7;
|
|
k = (_l != 0) ? (_j + (8 - _l)) : _j;
|
|
k = (k <= M) ? k : M;
|
|
for ( _i = _j; _i < k; ++_i )
|
|
{
|
|
bli_sscopys( *(C + _i*rsc + _j*ldc),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
}
|
|
k = (M - _i) >> 3;
|
|
_l = 0;
|
|
while ( _l < k )
|
|
{
|
|
ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
|
|
_mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
|
|
|
|
_i += 8;
|
|
_l++;
|
|
}
|
|
while (_i < M )
|
|
{
|
|
bli_sscopys( *(C + _i*rsc + _j*ldc),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
_i++;
|
|
}
|
|
_j++;
|
|
}
|
|
}
|
|
else //c is upper
|
|
{
|
|
for ( _j = 0; _j < N; ++_j )
|
|
{
|
|
k = (_j + 1) >> 3;
|
|
_i = 0;
|
|
_l = 0;
|
|
while ( _l < k )
|
|
{
|
|
ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
|
|
_mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
|
|
_i += 8;
|
|
_l++;
|
|
}
|
|
while (_i <= _j )
|
|
{
|
|
bli_sscopys( *(C + _i*rsc + _j*ldc),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
++_i;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
else
|
|
{//when beta is non-zero, fmadd and store the results
|
|
dim_t _i, _j, k, _l;
|
|
ymm1 = _mm256_broadcast_ss(beta_cast);
|
|
if(bli_obj_is_lower(c)) //c is lower
|
|
{
|
|
//first column
|
|
_j = 0;
|
|
k = M >> 3;
|
|
_i = 0;
|
|
for ( _l = 0; _l < k; _l++ )
|
|
{
|
|
ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC));
|
|
ymm0 = _mm256_loadu_ps((C + _i*rsc));
|
|
ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0);
|
|
_mm256_storeu_ps((matCbuf + _i*rs_matC), ymm0);
|
|
_i += 8;
|
|
}
|
|
while (_i < M )
|
|
{
|
|
bli_sssxpbys( *(C + _i*rsc + _j*ldc),
|
|
*(beta_cast),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
_i++;
|
|
}
|
|
_j++;
|
|
while ( _j < N ) //next column
|
|
{
|
|
//k = (_j + (8 - (_j & 7)));
|
|
_l = _j & 7;
|
|
k = (_l != 0) ? (_j + (8 - _l)) : _j;
|
|
k = (k <= M) ? k : M;
|
|
for ( _i = _j; _i < k; ++_i )
|
|
{
|
|
bli_sssxpbys( *(C + _i*rsc + _j*ldc),
|
|
*(beta_cast),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
}
|
|
k = (M - _i) >> 3;
|
|
_l = 0;
|
|
while ( _l < k )
|
|
{
|
|
ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC + _j*ldc_matC));
|
|
ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
|
|
ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0);
|
|
_mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
|
|
|
|
_i += 8;
|
|
_l++;
|
|
}
|
|
while (_i < M )
|
|
{
|
|
bli_sssxpbys( *(C + _i*rsc + _j*ldc),
|
|
*(beta_cast),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
_i++;
|
|
}
|
|
_j++;
|
|
}
|
|
}
|
|
else //c is upper
|
|
{
|
|
for ( _j = 0; _j < N; ++_j )
|
|
{
|
|
k = (_j + 1) >> 3;
|
|
_i = 0;
|
|
_l = 0;
|
|
while ( _l < k )
|
|
{
|
|
ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC + _j*ldc_matC));
|
|
ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
|
|
ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0);
|
|
_mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
|
|
_i += 8;
|
|
_l++;
|
|
}
|
|
while (_i <= _j )
|
|
{
|
|
bli_sssxpbys( *(C + _i*rsc + _j*ldc),
|
|
*(beta_cast),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
++_i;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return BLIS_SUCCESS;
|
|
}
|
|
else
|
|
return BLIS_NONCONFORMAL_DIMENSIONS;
|
|
|
|
|
|
}
|
|
|
|
static err_t bli_dsyrk_small
|
|
(
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
obj_t* beta,
|
|
obj_t* c,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
|
|
int M = bli_obj_length( c ); // number of rows of Matrix C
|
|
int N = bli_obj_width( c ); // number of columns of Matrix C
|
|
int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
|
|
int L = M * N;
|
|
|
|
	// Proceed with the small-matrix kernel only if the problem fits within the threshold dimensions.
|
|
if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES))
|
|
|| ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
|
|
{
|
|
|
|
int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
|
|
int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
|
|
int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
|
|
int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C
|
|
int row_idx, col_idx, k;
|
|
int rs_matC = bli_obj_row_stride( c );
|
|
int rsc = 1;
|
|
double *A = a->buffer; // pointer to elements of Matrix A
|
|
double *B = b->buffer; // pointer to elements of Matrix B
|
|
double *C = D_C_pack; // pointer to elements of Matrix C
|
|
double *matCbuf = c->buffer;
|
|
|
|
double *tA = A, *tB = B, *tC = C;//, *tA_pack;
|
|
		double *tA_packed; // temporary pointer to hold the packed A memory pointer
|
|
int row_idx_packed; //packed A memory row index
|
|
int lda_packed; //lda of packed A
|
|
int col_idx_start; //starting index after A matrix is packed.
|
|
dim_t tb_inc_row = 1; // row stride of matrix B
|
|
dim_t tb_inc_col = ldb; // column stride of matrix B
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
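		// Register usage in the D_MR x NR (16 x 3) micro-tile: ymm4..ymm15 hold
		// the twelve 4-double accumulators (four per column of C), ymm0..ymm2
		// hold broadcast elements of B, and ymm3 holds the current 4-double
		// slice of A.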
|
|
|
|
		int n_remainder; // remainder when N is not a multiple of 3 (N % 3)
		int m_remainder; // remainder when M is not a multiple of 16 (M % 16)
|
|
|
|
double *alpha_cast, *beta_cast; // alpha, beta multiples
|
|
alpha_cast = (alpha->buffer);
|
|
beta_cast = (beta->buffer);
|
|
int required_packing_A = 1;
|
|
|
|
// when N is equal to 1 call GEMV instead of SYRK
|
|
if (N == 1)
|
|
{
|
|
bli_gemv
|
|
(
|
|
alpha,
|
|
a,
|
|
b,
|
|
beta,
|
|
c
|
|
);
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
//update the pointer math if matrix B needs to be transposed.
|
|
if (bli_obj_has_trans( b ))
|
|
{
|
|
tb_inc_col = 1; //switch row and column strides
|
|
tb_inc_row = ldb;
|
|
}
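        // With the strides swapped, walking tB by tb_inc_row/tb_inc_col traverses
        // B^T in place, without physically transposing the data.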
|
|
|
|
if ((N <= 3) || ((D_MR * K) > D_SCRATCH_DIM))
|
|
{
|
|
required_packing_A = 0;
|
|
}
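        // Packing is skipped when only a single NR-wide column panel exists (N <= 3)
        // or when a D_MR x K panel of A would not fit in the static scratch buffer.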
|
|
/*
|
|
         * The computation loop processes a D_MR x N panel of the C matrix,
|
|
         * accessing a D_MR x K panel of A and K x NR panels of B.
|
|
* The computation is organized as inner loops of dimension D_MRxNR.
|
|
*/
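        // Register blocking: each D_MR x NR (16 x 3) tile of C is held in twelve
        // accumulators (ymm4..ymm15); each register covers 4 consecutive doubles of
        // one column, and B elements are broadcast one scalar at a time.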
|
|
// Process D_MR rows of C matrix at a time.
|
|
for (row_idx = 0; (row_idx + (D_MR - 1)) < M; row_idx += D_MR)
|
|
{
|
|
|
|
col_idx_start = 0;
|
|
tA_packed = A;
|
|
row_idx_packed = row_idx;
|
|
lda_packed = lda;
|
|
|
|
            // This is part of the pack-and-compute optimization.
|
|
// During the first column iteration, we store the accessed A matrix into
|
|
            // contiguous static memory. This helps keep the A matrix in cache and
|
|
            // avoids TLB misses.
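            // Packed layout: element A(row_idx + r, k) is stored at D_A_pack[k*D_MR + r],
            // so subsequent column panels stream A contiguously with stride lda_packed = D_MR.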
|
|
if (required_packing_A)
|
|
{
|
|
col_idx = 0;
|
|
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
tA_packed = D_A_pack;
|
|
|
|
#if 0//def BLIS_ENABLE_PREFETCH
|
|
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
|
|
#endif
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
// This loop is processing D_MR x K
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
_mm256_storeu_pd(tA_packed, ymm3); // the packing of matrix A
|
|
// ymm4 += ymm0 * ymm3;
|
|
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
|
|
// ymm8 += ymm1 * ymm3;
|
|
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
|
|
// ymm12 += ymm2 * ymm3;
|
|
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 4);
|
|
_mm256_storeu_pd(tA_packed + 4, ymm3); // the packing of matrix A
|
|
// ymm5 += ymm0 * ymm3;
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
// ymm9 += ymm1 * ymm3;
|
|
ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
|
|
// ymm13 += ymm2 * ymm3;
|
|
ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 8);
|
|
_mm256_storeu_pd(tA_packed + 8, ymm3); // the packing of matrix A
|
|
// ymm6 += ymm0 * ymm3;
|
|
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
|
|
// ymm10 += ymm1 * ymm3;
|
|
ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
|
|
// ymm14 += ymm2 * ymm3;
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 12);
|
|
_mm256_storeu_pd(tA_packed + 12, ymm3); // the packing of matrix A
|
|
// ymm7 += ymm0 * ymm3;
|
|
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
|
|
// ymm11 += ymm1 * ymm3;
|
|
ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
|
|
// ymm15 += ymm2 * ymm3;
|
|
ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
|
|
|
|
tA += lda;
|
|
tA_packed += D_MR;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_pd(ymm6, ymm0);
|
|
ymm7 = _mm256_mul_pd(ymm7, ymm0);
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0);
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0);
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0);
|
|
|
|
// multiply C by beta and accumulate col 1.
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
|
|
ymm2 = _mm256_loadu_pd(tC + 12);
|
|
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_pd(tC, ymm4);
|
|
_mm256_storeu_pd(tC + 4, ymm5);
|
|
_mm256_storeu_pd(tC + 8, ymm6);
|
|
_mm256_storeu_pd(tC + 12, ymm7);
|
|
|
|
// multiply C by beta and accumulate, col 2.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
|
|
ymm2 = _mm256_loadu_pd(tC + 12);
|
|
ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/
|
|
_mm256_storeu_pd(tC, ymm8);
|
|
_mm256_storeu_pd(tC + 4, ymm9);
|
|
_mm256_storeu_pd(tC + 8, ymm10);
|
|
_mm256_storeu_pd(tC + 12, ymm11);
|
|
|
|
// multiply C by beta and accumulate, col 3.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
|
|
ymm2 = _mm256_loadu_pd(tC + 12);
|
|
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
|
|
_mm256_storeu_pd(tC, ymm12);
|
|
_mm256_storeu_pd(tC + 4, ymm13);
|
|
_mm256_storeu_pd(tC + 8, ymm14);
|
|
_mm256_storeu_pd(tC + 12, ymm15);
|
|
|
|
            // modify the pointer arithmetic to use the packed A matrix.
|
|
col_idx_start = NR;
|
|
tA_packed = D_A_pack;
|
|
row_idx_packed = 0;
|
|
lda_packed = D_MR;
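            // From here on, A is read from the packed buffer: row_idx_packed = 0 and
            // lda_packed = D_MR make the remaining column panels stride through
            // D_A_pack instead of the original A.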
|
|
}
|
|
// Process NR columns of C matrix at a time.
|
|
for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = tA_packed + row_idx_packed;
|
|
|
|
#if 0//def BLIS_ENABLE_PREFETCH
|
|
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
|
|
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
|
|
#endif
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
// This loop is processing D_MR x K
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
// ymm4 += ymm0 * ymm3;
|
|
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
|
|
// ymm8 += ymm1 * ymm3;
|
|
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
|
|
// ymm12 += ymm2 * ymm3;
|
|
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 4);
|
|
// ymm5 += ymm0 * ymm3;
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
// ymm9 += ymm1 * ymm3;
|
|
ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
|
|
// ymm13 += ymm2 * ymm3;
|
|
ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 8);
|
|
// ymm6 += ymm0 * ymm3;
|
|
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
|
|
// ymm10 += ymm1 * ymm3;
|
|
ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
|
|
// ymm14 += ymm2 * ymm3;
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 12);
|
|
// ymm7 += ymm0 * ymm3;
|
|
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
|
|
// ymm11 += ymm1 * ymm3;
|
|
ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
|
|
// ymm15 += ymm2 * ymm3;
|
|
ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
|
|
|
|
tA += lda_packed;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_pd(ymm6, ymm0);
|
|
ymm7 = _mm256_mul_pd(ymm7, ymm0);
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0);
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0);
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0);
|
|
|
|
// multiply C by beta and accumulate col 1.
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
|
|
ymm2 = _mm256_loadu_pd(tC + 12);
|
|
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_pd(tC, ymm4);
|
|
_mm256_storeu_pd(tC + 4, ymm5);
|
|
_mm256_storeu_pd(tC + 8, ymm6);
|
|
_mm256_storeu_pd(tC + 12, ymm7);
|
|
|
|
// multiply C by beta and accumulate, col 2.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
|
|
ymm2 = _mm256_loadu_pd(tC + 12);
|
|
ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/
|
|
_mm256_storeu_pd(tC, ymm8);
|
|
_mm256_storeu_pd(tC + 4, ymm9);
|
|
_mm256_storeu_pd(tC + 8, ymm10);
|
|
_mm256_storeu_pd(tC + 12, ymm11);
|
|
|
|
// multiply C by beta and accumulate, col 3.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
|
|
ymm2 = _mm256_loadu_pd(tC + 12);
|
|
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
|
|
_mm256_storeu_pd(tC, ymm12);
|
|
_mm256_storeu_pd(tC + 4, ymm13);
|
|
_mm256_storeu_pd(tC + 8, ymm14);
|
|
_mm256_storeu_pd(tC + 12, ymm15);
|
|
|
|
}
|
|
n_remainder = N - col_idx;
|
|
|
|
        // if N is not a multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 2)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
|
|
ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 4);
|
|
ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 8);
|
|
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
|
|
ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 12);
|
|
ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11);
|
|
ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);
|
|
|
|
tA += lda;
|
|
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0);
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0);
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0);
|
|
|
|
// multiply C by beta and accumulate, col 1.
|
|
/*ymm2 = _mm256_loadu_pd(tC + 0);
|
|
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
|
|
ymm2 = _mm256_loadu_pd(tC + 12);
|
|
ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/
|
|
_mm256_storeu_pd(tC + 0, ymm8);
|
|
_mm256_storeu_pd(tC + 4, ymm9);
|
|
_mm256_storeu_pd(tC + 8, ymm10);
|
|
_mm256_storeu_pd(tC + 12, ymm11);
|
|
|
|
// multiply C by beta and accumulate, col 2.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
|
|
ymm2 = _mm256_loadu_pd(tC + 12);
|
|
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
|
|
_mm256_storeu_pd(tC, ymm12);
|
|
_mm256_storeu_pd(tC + 4, ymm13);
|
|
_mm256_storeu_pd(tC + 8, ymm14);
|
|
_mm256_storeu_pd(tC + 12, ymm15);
|
|
|
|
col_idx += 2;
|
|
}
|
|
        // if N is not a multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 4);
|
|
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 8);
|
|
ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 12);
|
|
ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15);
|
|
|
|
tA += lda;
|
|
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_pd(tC + 0);
|
|
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
|
|
ymm2 = _mm256_loadu_pd(tC + 12);
|
|
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
|
|
|
|
_mm256_storeu_pd(tC + 0, ymm12);
|
|
_mm256_storeu_pd(tC + 4, ymm13);
|
|
_mm256_storeu_pd(tC + 8, ymm14);
|
|
_mm256_storeu_pd(tC + 12, ymm15);
|
|
}
|
|
}
|
|
|
|
m_remainder = M - row_idx;
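    // Leftover rows (m_remainder < 16) are handled in decreasing vector-sized blocks
    // of 12, 8 and 4 rows; anything smaller than 4 rows falls through to the
    // padded/scalar paths below.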
|
|
|
|
if (m_remainder >= 12)
|
|
{
|
|
m_remainder -= 12;
|
|
|
|
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
// ymm4 += ymm0 * ymm3;
|
|
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
|
|
// ymm8 += ymm1 * ymm3;
|
|
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
|
|
// ymm12 += ymm2 * ymm3;
|
|
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 4);
|
|
// ymm5 += ymm0 * ymm3;
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
// ymm9 += ymm1 * ymm3;
|
|
ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
|
|
// ymm13 += ymm2 * ymm3;
|
|
ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 8);
|
|
// ymm6 += ymm0 * ymm3;
|
|
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
|
|
// ymm10 += ymm1 * ymm3;
|
|
ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
|
|
// ymm14 += ymm2 * ymm3;
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_pd(ymm6, ymm0);
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0);
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);*/
|
|
_mm256_storeu_pd(tC, ymm4);
|
|
_mm256_storeu_pd(tC + 4, ymm5);
|
|
_mm256_storeu_pd(tC + 8, ymm6);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);*/
|
|
_mm256_storeu_pd(tC, ymm8);
|
|
_mm256_storeu_pd(tC + 4, ymm9);
|
|
_mm256_storeu_pd(tC + 8, ymm10);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/
|
|
_mm256_storeu_pd(tC, ymm12);
|
|
_mm256_storeu_pd(tC + 4, ymm13);
|
|
_mm256_storeu_pd(tC + 8, ymm14);
|
|
|
|
}
|
|
n_remainder = N - col_idx;
|
|
        // if N is not a multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 2)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
|
|
ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 4);
|
|
ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 8);
|
|
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
|
|
ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
|
|
|
|
tA += lda;
|
|
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0);
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_pd(tC + 0);
|
|
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);*/
|
|
_mm256_storeu_pd(tC + 0, ymm8);
|
|
_mm256_storeu_pd(tC + 4, ymm9);
|
|
_mm256_storeu_pd(tC + 8, ymm10);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/
|
|
_mm256_storeu_pd(tC, ymm12);
|
|
_mm256_storeu_pd(tC + 4, ymm13);
|
|
_mm256_storeu_pd(tC + 8, ymm14);
|
|
|
|
col_idx += 2;
|
|
}
|
|
        // if N is not a multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 4);
|
|
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 8);
|
|
ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
|
|
|
|
tA += lda;
|
|
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0);
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_pd(tC + 0);
|
|
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
|
|
ymm2 = _mm256_loadu_pd(tC + 8);
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/
|
|
|
|
_mm256_storeu_pd(tC + 0, ymm12);
|
|
_mm256_storeu_pd(tC + 4, ymm13);
|
|
_mm256_storeu_pd(tC + 8, ymm14);
|
|
}
|
|
|
|
row_idx += 12;
|
|
}
|
|
|
|
if (m_remainder >= 8)
|
|
{
|
|
m_remainder -= 8;
|
|
|
|
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
|
|
ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
|
|
ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 4);
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
|
|
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_pd(ymm6, ymm0);
|
|
ymm7 = _mm256_mul_pd(ymm7, ymm0);
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0);
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_pd(tC, ymm4);
|
|
_mm256_storeu_pd(tC + 4, ymm5);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_pd(tC, ymm6);
|
|
_mm256_storeu_pd(tC + 4, ymm7);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);*/
|
|
_mm256_storeu_pd(tC, ymm8);
|
|
_mm256_storeu_pd(tC + 4, ymm9);
|
|
|
|
}
|
|
n_remainder = N - col_idx;
|
|
        // if N is not a multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 2)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
|
|
ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 4);
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_pd(ymm6, ymm0);
|
|
ymm7 = _mm256_mul_pd(ymm7, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_pd(tC, ymm4);
|
|
_mm256_storeu_pd(tC + 4, ymm5);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_pd(tC, ymm6);
|
|
_mm256_storeu_pd(tC + 4, ymm7);
|
|
|
|
col_idx += 2;
|
|
|
|
}
|
|
        // if N is not a multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
|
|
|
|
ymm3 = _mm256_loadu_pd(tA + 4);
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
|
|
ymm2 = _mm256_loadu_pd(tC + 4);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_pd(tC, ymm4);
|
|
_mm256_storeu_pd(tC + 4, ymm5);
|
|
|
|
}
|
|
|
|
row_idx += 8;
|
|
}
|
|
|
|
if (m_remainder >= 4)
|
|
{
|
|
m_remainder -= 4;
|
|
|
|
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
|
|
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
|
|
ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
ymm6 = _mm256_mul_pd(ymm6, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/
|
|
_mm256_storeu_pd(tC, ymm4);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_pd(tC, ymm5);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);*/
|
|
_mm256_storeu_pd(tC, ymm6);
|
|
}
|
|
n_remainder = N - col_idx;
|
|
        // if N is not a multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 2)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
|
|
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm0);
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/
|
|
_mm256_storeu_pd(tC, ymm4);
|
|
|
|
// multiply C by beta and accumulate.
|
|
tC += ldc;
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_pd(tC, ymm5);
|
|
|
|
col_idx += 2;
|
|
|
|
}
|
|
        // if N is not a multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm0);
|
|
|
|
// multiply C by beta and accumulate.
|
|
/*ymm2 = _mm256_loadu_pd(tC);
|
|
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/
|
|
_mm256_storeu_pd(tC, ymm4);
|
|
|
|
}
|
|
|
|
row_idx += 4;
|
|
}
|
|
    // M is not a multiple of 4.
|
|
// The handling of edge case where the remainder
|
|
    // dimension is less than 4. The padding takes place
|
|
// to handle this case.
|
|
if ((m_remainder) && (lda > 3))
|
|
{
|
|
double f_temp[8];
|
|
|
|
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
// clear scratch registers.
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < (K - 1); ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
//broadcasted matrix B elements are multiplied
|
|
//with matrix A columns.
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
|
|
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
|
|
|
|
tA += lda;
|
|
}
|
|
// alpha, beta multiplication.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
|
|
tB += tb_inc_row;
|
|
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tA[i];
|
|
}
|
|
ymm3 = _mm256_loadu_pd(f_temp);
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
|
|
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
|
|
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
//multiply A*B by alpha.
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
ymm7 = _mm256_mul_pd(ymm7, ymm0);
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
|
|
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_pd(f_temp);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_pd(f_temp, ymm5);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
|
|
tC += ldc;
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_pd(f_temp);
|
|
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_pd(f_temp, ymm7);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
|
|
tC += ldc;
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_pd(f_temp);
|
|
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);*/
|
|
_mm256_storeu_pd(f_temp, ymm9);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
}
|
|
n_remainder = N - col_idx;
|
|
        // if N is not a multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 2)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < (K - 1); ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
|
|
|
|
tA += lda;
|
|
}
|
|
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
|
|
tB += tb_inc_row;
|
|
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tA[i];
|
|
}
|
|
ymm3 = _mm256_loadu_pd(f_temp);
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
|
|
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
ymm7 = _mm256_mul_pd(ymm7, ymm0);
|
|
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_pd(f_temp);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_pd(f_temp, ymm5);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
|
|
tC += ldc;
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_pd(f_temp);
|
|
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
|
|
_mm256_storeu_pd(f_temp, ymm7);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
}
|
|
        // if N is not a multiple of 3.
|
|
// handling edge case.
|
|
if (n_remainder == 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
ymm5 = _mm256_setzero_pd();
|
|
|
|
for (k = 0; k < (K - 1); ++k)
|
|
{
|
|
// The inner loop broadcasts the B matrix data and
|
|
// multiplies it with the A matrix.
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
ymm3 = _mm256_loadu_pd(tA);
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
|
|
tA += lda;
|
|
}
|
|
|
|
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
|
|
tB += tb_inc_row;
|
|
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tA[i];
|
|
}
|
|
ymm3 = _mm256_loadu_pd(f_temp);
|
|
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
|
|
|
|
ymm0 = _mm256_broadcast_sd(alpha_cast);
|
|
//ymm1 = _mm256_broadcast_sd(beta_cast);
|
|
|
|
// multiply C by beta and accumulate.
|
|
ymm5 = _mm256_mul_pd(ymm5, ymm0);
|
|
|
|
/*for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
f_temp[i] = tC[i];
|
|
}
|
|
ymm2 = _mm256_loadu_pd(f_temp);
|
|
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
|
|
_mm256_storeu_pd(f_temp, ymm5);
|
|
for (int i = 0; i < m_remainder; i++)
|
|
{
|
|
tC[i] = f_temp[i];
|
|
}
|
|
}
|
|
m_remainder = 0;
|
|
}
|
|
|
|
if (m_remainder)
|
|
{
|
|
double result;
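        // Fully scalar fallback for the last (fewer-than-4) rows: each remaining
        // element of C is computed as a K-length dot product scaled by alpha.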
|
|
for (; row_idx < M; row_idx += 1)
|
|
{
|
|
for (col_idx = 0; col_idx < N; col_idx += 1)
|
|
{
|
|
//pointer math to point to proper memory
|
|
tC = C + ldc * col_idx + row_idx;
|
|
tB = B + tb_inc_col * col_idx;
|
|
tA = A + row_idx;
|
|
|
|
result = 0;
|
|
for (k = 0; k < K; ++k)
|
|
{
|
|
result += (*tA) * (*tB);
|
|
tA += lda;
|
|
tB += tb_inc_row;
|
|
}
|
|
|
|
result *= (*alpha_cast);
|
|
(*tC) = /*(*tC) * (*beta_cast) + */result;
|
|
}
|
|
}
|
|
}
|
|
|
|
    //copy/compute syrk values back to C using SIMD
|
|
if ( bli_seq0( *beta_cast ) )
|
|
{//just copy for beta = 0
|
|
dim_t _i, _j, k, _l;
|
|
if(bli_obj_is_lower(c)) //c is lower
|
|
{
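        // For lower-triangular C, column _j keeps rows _i >= _j: a scalar prologue
        // copies rows from _j up to the next multiple of 4 (capped at M), 4-wide vector
        // copies handle the aligned middle, and a scalar loop finishes any tail rows.
        // The first column starts at row 0 and is vectorized directly.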
|
|
//first column
|
|
_j = 0;
|
|
k = M >> 2;
|
|
_i = 0;
|
|
for ( _l = 0; _l < k; _l++ )
|
|
{
|
|
ymm0 = _mm256_loadu_pd((C + _i*rsc));
|
|
_mm256_storeu_pd((matCbuf + _i*rs_matC), ymm0);
|
|
_i += 4;
|
|
}
|
|
while (_i < M )
|
|
{
|
|
bli_ddcopys( *(C + _i*rsc + _j*ldc),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
_i++;
|
|
}
|
|
_j++;
|
|
while ( _j < N ) //next column
|
|
{
|
|
//k = (_j + (4 - (_j & 3)));
|
|
_l = _j & 3;
|
|
k = (_l != 0) ? (_j + (4 - _l)) : _j;
|
|
k = (k <= M) ? k : M;
|
|
for ( _i = _j; _i < k; ++_i )
|
|
{
|
|
bli_ddcopys( *(C + _i*rsc + _j*ldc),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
}
|
|
k = (M - _i) >> 2;
|
|
_l = 0;
|
|
while ( _l < k )
|
|
{
|
|
ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
|
|
_mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
|
|
|
|
_i += 4;
|
|
_l++;
|
|
}
|
|
while (_i < M )
|
|
{
|
|
bli_ddcopys( *(C + _i*rsc + _j*ldc),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
_i++;
|
|
}
|
|
_j++;
|
|
}
|
|
}
|
|
else //c is upper
|
|
{
|
|
for ( _j = 0; _j < N; ++_j )
|
|
{
|
|
k = (_j + 1) >> 2;
|
|
_i = 0;
|
|
_l = 0;
|
|
while ( _l < k )
|
|
{
|
|
ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
|
|
_mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
|
|
_i += 4;
|
|
_l++;
|
|
}
|
|
while (_i <= _j )
|
|
{
|
|
bli_ddcopys( *(C + _i*rsc + _j*ldc),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
++_i;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
else
|
|
{//when beta is non-zero, fmadd and store the results
|
|
dim_t _i, _j, k, _l;
|
|
ymm1 = _mm256_broadcast_sd(beta_cast);
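        // C is updated as matCbuf := beta*matCbuf + C_scratch, where C_scratch already
        // holds alpha*A*op(A); the fmadd below fuses the beta scaling with the
        // accumulation, 4 doubles at a time.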
|
|
if(bli_obj_is_lower(c)) //c is lower
|
|
{
|
|
//first column
|
|
_j = 0;
|
|
k = M >> 2;
|
|
_i = 0;
|
|
for ( _l = 0; _l < k; _l++ )
|
|
{
|
|
ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC));
|
|
ymm0 = _mm256_loadu_pd((C + _i*rsc));
|
|
ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0);
|
|
_mm256_storeu_pd((matCbuf + _i*rs_matC), ymm0);
|
|
_i += 4;
|
|
}
|
|
while (_i < M )
|
|
{
|
|
bli_dddxpbys( *(C + _i*rsc + _j*ldc),
|
|
*(beta_cast),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
_i++;
|
|
}
|
|
_j++;
|
|
while ( _j < N ) //next column
|
|
{
|
|
//k = (_j + (4 - (_j & 3)));
|
|
_l = _j & 3;
|
|
k = (_l != 0) ? (_j + (4 - _l)) : _j;
|
|
k = (k <= M) ? k : M;
|
|
for ( _i = _j; _i < k; ++_i )
|
|
{
|
|
bli_dddxpbys( *(C + _i*rsc + _j*ldc),
|
|
*(beta_cast),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
}
|
|
k = (M - _i) >> 2;
|
|
_l = 0;
|
|
while ( _l < k )
|
|
{
|
|
ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC + _j*ldc_matC));
|
|
ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
|
|
ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0);
|
|
_mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
|
|
|
|
_i += 4;
|
|
_l++;
|
|
}
|
|
while (_i < M )
|
|
{
|
|
bli_dddxpbys( *(C + _i*rsc + _j*ldc),
|
|
*(beta_cast),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
_i++;
|
|
}
|
|
_j++;
|
|
}
|
|
}
|
|
else //c is upper
|
|
{
|
|
for ( _j = 0; _j < N; ++_j )
|
|
{
|
|
k = (_j + 1) >> 2;
|
|
_i = 0;
|
|
_l = 0;
|
|
while ( _l < k )
|
|
{
|
|
ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC + _j*ldc_matC));
|
|
ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
|
|
ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0);
|
|
_mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
|
|
_i += 4;
|
|
_l++;
|
|
}
|
|
while (_i <= _j )
|
|
{
|
|
bli_dddxpbys( *(C + _i*rsc + _j*ldc),
|
|
*(beta_cast),
|
|
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
|
|
++_i;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return BLIS_SUCCESS;
|
|
}
|
|
else
|
|
return BLIS_NONCONFORMAL_DIMENSIONS;
|
|
|
|
|
|
}
|
|
|
|
static err_t bli_ssyrk_small_atbn
|
|
(
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
obj_t* beta,
|
|
obj_t* c,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
int M = bli_obj_length(c); // number of rows of Matrix C
|
|
int N = bli_obj_width(c); // number of columns of Matrix C
|
|
int K = bli_obj_length(b); // number of rows of Matrix B
|
|
int lda = bli_obj_col_stride(a); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
|
|
int ldb = bli_obj_col_stride(b); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
|
|
int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
|
|
int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C
|
|
int row_idx = 0, col_idx = 0, k;
|
|
int rs_matC = bli_obj_row_stride( c );
|
|
int rsc = 1;
|
|
float *A = a->buffer; // pointer to matrix A elements, stored in row major format
|
|
float *B = b->buffer; // pointer to matrix B elements, stored in column major format
|
|
float *C = C_pack; // pointer to matrix C elements, stored in column major format
|
|
float *matCbuf = c->buffer;
|
|
|
|
float *tA = A, *tB = B, *tC = C;
|
|
|
|
__m256 ymm4, ymm5, ymm6, ymm7;
|
|
__m256 ymm8, ymm9, ymm10, ymm11;
|
|
__m256 ymm12, ymm13, ymm14, ymm15;
|
|
__m256 ymm0, ymm1, ymm2, ymm3;
|
|
|
|
float result, scratch[8];
|
|
float *alpha_cast, *beta_cast; // alpha, beta multiples
|
|
alpha_cast = (alpha->buffer);
|
|
beta_cast = (beta->buffer);
|
|
|
|
// The non-copy version of the A^T SYRK gives better performance for the small M cases.
|
|
// The threshold is controlled by BLIS_ATBN_M_THRES
|
|
if (M <= BLIS_ATBN_M_THRES)
|
|
{
|
|
for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
|
|
{
|
|
for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
|
|
{
|
|
tA = A + row_idx * lda;
|
|
tB = B + col_idx * ldb;
|
|
tC = C + col_idx * ldc + row_idx;
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_ps();
|
|
ymm5 = _mm256_setzero_ps();
|
|
ymm6 = _mm256_setzero_ps();
|
|
ymm7 = _mm256_setzero_ps();
|
|
ymm8 = _mm256_setzero_ps();
|
|
ymm9 = _mm256_setzero_ps();
|
|
ymm10 = _mm256_setzero_ps();
|
|
ymm11 = _mm256_setzero_ps();
|
|
ymm12 = _mm256_setzero_ps();
|
|
ymm13 = _mm256_setzero_ps();
|
|
ymm14 = _mm256_setzero_ps();
|
|
ymm15 = _mm256_setzero_ps();
|
|
|
|
                //The inner loop computes a 4x3 block of the matrix.
|
|
//The computation pattern is:
|
|
// ymm4 ymm5 ymm6
|
|
// ymm7 ymm8 ymm9
|
|
// ymm10 ymm11 ymm12
|
|
// ymm13 ymm14 ymm15
|
|
|
|
                //The dot product is performed in the inner loop; 8 float elements fit
|
|
                //in a YMM register, hence the loop count is incremented by 8.
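                // Each accumulator holds 8 partial products; after the k loop, two hadd
                // operations reduce within each 128-bit lane, and the scalar result is
                // recovered as scratch[0] + scratch[4].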
|
|
for (k = 0; (k + 7) < K; k += 8)
|
|
{
|
|
ymm0 = _mm256_loadu_ps(tB + 0);
|
|
ymm1 = _mm256_loadu_ps(tB + ldb);
|
|
ymm2 = _mm256_loadu_ps(tB + 2 * ldb);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
|
|
ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + lda);
|
|
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
|
|
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 2 * lda);
|
|
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
|
|
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 3 * lda);
|
|
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
|
|
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
|
|
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
|
|
|
|
tA += 8;
|
|
tB += 8;
|
|
|
|
}
|
|
|
|
                // if K is not a multiple of 8, padding is done before the load using a temporary array.
|
|
if (k < K)
|
|
{
|
|
int iter;
|
|
float data_feeder[8] = { 0.0 };
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
|
|
ymm0 = _mm256_loadu_ps(data_feeder);
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
|
|
ymm1 = _mm256_loadu_ps(data_feeder);
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
|
|
ymm2 = _mm256_loadu_ps(data_feeder);
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
|
|
ymm3 = _mm256_loadu_ps(data_feeder);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
|
|
ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
|
|
ymm3 = _mm256_loadu_ps(data_feeder);
|
|
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
|
|
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
|
|
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
|
|
ymm3 = _mm256_loadu_ps(data_feeder);
|
|
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
|
|
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
|
|
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
|
|
ymm3 = _mm256_loadu_ps(data_feeder);
|
|
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
|
|
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
|
|
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
|
|
|
|
}
|
|
|
|
//horizontal addition and storage of the data.
|
|
                //Results for the 4x3 block of C are stored here
|
|
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
|
|
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
|
|
_mm256_storeu_ps(scratch, ymm4);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[0] = result/* + tC[0] * (*beta_cast)*/;
|
|
|
|
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
|
|
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
|
|
_mm256_storeu_ps(scratch, ymm7);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[1] = result/* + tC[1] * (*beta_cast)*/;
|
|
|
|
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
|
|
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
|
|
_mm256_storeu_ps(scratch, ymm10);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[2] = result/* + tC[2] * (*beta_cast)*/;
|
|
|
|
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
|
|
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
|
|
_mm256_storeu_ps(scratch, ymm13);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[3] = result/* + tC[3] * (*beta_cast)*/;
|
|
|
|
tC += ldc;
|
|
ymm5 = _mm256_hadd_ps(ymm5, ymm5);
|
|
ymm5 = _mm256_hadd_ps(ymm5, ymm5);
|
|
_mm256_storeu_ps(scratch, ymm5);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[0] = result/* + tC[0] * (*beta_cast)*/;
|
|
|
|
ymm8 = _mm256_hadd_ps(ymm8, ymm8);
|
|
ymm8 = _mm256_hadd_ps(ymm8, ymm8);
|
|
_mm256_storeu_ps(scratch, ymm8);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[1] = result/* + tC[1] * (*beta_cast)*/;
|
|
|
|
ymm11 = _mm256_hadd_ps(ymm11, ymm11);
|
|
ymm11 = _mm256_hadd_ps(ymm11, ymm11);
|
|
_mm256_storeu_ps(scratch, ymm11);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[2] = result/* + tC[2] * (*beta_cast)*/;
|
|
|
|
ymm14 = _mm256_hadd_ps(ymm14, ymm14);
|
|
ymm14 = _mm256_hadd_ps(ymm14, ymm14);
|
|
_mm256_storeu_ps(scratch, ymm14);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[3] = result/* + tC[3] * (*beta_cast)*/;
|
|
|
|
tC += ldc;
|
|
ymm6 = _mm256_hadd_ps(ymm6, ymm6);
|
|
ymm6 = _mm256_hadd_ps(ymm6, ymm6);
|
|
_mm256_storeu_ps(scratch, ymm6);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[0] = result/* + tC[0] * (*beta_cast)*/;
|
|
|
|
ymm9 = _mm256_hadd_ps(ymm9, ymm9);
|
|
ymm9 = _mm256_hadd_ps(ymm9, ymm9);
|
|
_mm256_storeu_ps(scratch, ymm9);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[1] = result/* + tC[1] * (*beta_cast)*/;
|
|
|
|
ymm12 = _mm256_hadd_ps(ymm12, ymm12);
|
|
ymm12 = _mm256_hadd_ps(ymm12, ymm12);
|
|
_mm256_storeu_ps(scratch, ymm12);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[2] = result/* + tC[2] * (*beta_cast)*/;
|
|
|
|
ymm15 = _mm256_hadd_ps(ymm15, ymm15);
|
|
ymm15 = _mm256_hadd_ps(ymm15, ymm15);
|
|
_mm256_storeu_ps(scratch, ymm15);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[3] = result/* + tC[3] * (*beta_cast)*/;
|
|
}
|
|
}
|
|
|
|
int processed_col = col_idx;
|
|
int processed_row = row_idx;
|
|
|
|
//The edge case handling where N is not a multiple of 3
|
|
if (processed_col < N)
|
|
{
|
|
for (col_idx = processed_col; col_idx < N; col_idx += 1)
|
|
{
|
|
for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
|
|
{
|
|
tA = A + row_idx * lda;
|
|
tB = B + col_idx * ldb;
|
|
tC = C + col_idx * ldc + row_idx;
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_ps();
|
|
ymm7 = _mm256_setzero_ps();
|
|
ymm10 = _mm256_setzero_ps();
|
|
ymm13 = _mm256_setzero_ps();
|
|
|
|
                    //The inner loop computes a 4x1 block of the matrix.
|
|
//The computation pattern is:
|
|
// ymm4
|
|
// ymm7
|
|
// ymm10
|
|
// ymm13
|
|
|
|
for (k = 0; (k + 7) < K; k += 8)
|
|
{
|
|
ymm0 = _mm256_loadu_ps(tB + 0);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + lda);
|
|
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 2 * lda);
|
|
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
|
|
|
|
ymm3 = _mm256_loadu_ps(tA + 3 * lda);
|
|
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
|
|
|
|
tA += 8;
|
|
tB += 8;
|
|
}
|
|
|
|
                    // if K is not a multiple of 8, padding is done before the load using a temporary array.
|
|
if (k < K)
|
|
{
|
|
int iter;
|
|
float data_feeder[8] = { 0.0 };
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
|
|
ymm0 = _mm256_loadu_ps(data_feeder);
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
|
|
ymm3 = _mm256_loadu_ps(data_feeder);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
|
|
ymm3 = _mm256_loadu_ps(data_feeder);
|
|
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
|
|
ymm3 = _mm256_loadu_ps(data_feeder);
|
|
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
|
|
ymm3 = _mm256_loadu_ps(data_feeder);
|
|
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
|
|
|
|
}
|
|
|
|
//horizontal addition and storage of the data.
|
|
                    //Results for the 4x1 block of C are stored here
|
|
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
|
|
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
|
|
_mm256_storeu_ps(scratch, ymm4);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[0] = result/* + tC[0] * (*beta_cast)*/;
|
|
|
|
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
|
|
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
|
|
_mm256_storeu_ps(scratch, ymm7);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[1] = result/* + tC[1] * (*beta_cast)*/;
|
|
|
|
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
|
|
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
|
|
_mm256_storeu_ps(scratch, ymm10);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[2] = result/* + tC[2] * (*beta_cast)*/;
|
|
|
|
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
|
|
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
|
|
_mm256_storeu_ps(scratch, ymm13);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[3] = result/* + tC[3] * (*beta_cast)*/;
|
|
|
|
}
|
|
}
|
|
processed_row = row_idx;
|
|
}
|
|
|
|
//The edge case handling where M is not a multiple of 4
|
|
if (processed_row < M)
|
|
{
|
|
for (row_idx = processed_row; row_idx < M; row_idx += 1)
|
|
{
|
|
for (col_idx = 0; col_idx < N; col_idx += 1)
|
|
{
|
|
tA = A + row_idx * lda;
|
|
tB = B + col_idx * ldb;
|
|
tC = C + col_idx * ldc + row_idx;
|
|
// clear scratch registers.
|
|
ymm4 = _mm256_setzero_ps();
|
|
|
|
for (k = 0; (k + 7) < K; k += 8)
|
|
{
|
|
ymm0 = _mm256_loadu_ps(tB + 0);
|
|
ymm3 = _mm256_loadu_ps(tA);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
|
|
tA += 8;
|
|
tB += 8;
|
|
}
|
|
|
|
                    // if K is not a multiple of 8, padding is done before the load using a temporary array.
|
|
if (k < K)
|
|
{
|
|
int iter;
|
|
float data_feeder[8] = { 0.0 };
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
|
|
ymm0 = _mm256_loadu_ps(data_feeder);
|
|
|
|
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
|
|
ymm3 = _mm256_loadu_ps(data_feeder);
|
|
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
|
|
|
|
}
|
|
|
|
//horizontal addition and storage of the data.
|
|
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
|
|
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
|
|
_mm256_storeu_ps(scratch, ymm4);
|
|
result = scratch[0] + scratch[4];
|
|
result *= (*alpha_cast);
|
|
tC[0] = result/* + tC[0] * (*beta_cast)*/;
|
|
|
|
}
|
|
}
|
|
}
|
|
|
|
    //copy/compute syrk values back to C
|
|
		if ( bli_seq0( *beta_cast ) ) //when beta is 0, just copy result to C
		{
			dim_t _i, _j;
			if(bli_obj_is_lower(c)) //c is lower
			{
				for ( _j = 0; _j < N; ++_j )
					for ( _i = 0; _i < M; ++_i )
						if ( (doff_t)_j - (doff_t)_i <= 0 )
						{
							bli_sscopys( *(C + _i*rsc + _j*ldc),
							             *(matCbuf + _i*rs_matC + _j*ldc_matC) );
						}
			}
			else //c is upper
			{
				for ( _j = 0; _j < N; ++_j )
					for ( _i = 0; _i < M; ++_i )
						if ( (doff_t)_j - (doff_t)_i >= 0 )
						{
							bli_sscopys( *(C + _i*rsc + _j*ldc),
							             *(matCbuf + _i*rs_matC + _j*ldc_matC) );
						}
			}
		}
		else //when beta is non-zero, multiply and store result to C
		{
			dim_t _i, _j;
			if(bli_obj_is_lower(c)) //c is lower
			{
				for ( _j = 0; _j < N; ++_j )
					for ( _i = 0; _i < M; ++_i )
						if ( (doff_t)_j - (doff_t)_i <= 0 )
						{
							bli_sssxpbys( *(C + _i*rsc + _j*ldc),
							              *(beta_cast),
							              *(matCbuf + _i*rs_matC + _j*ldc_matC) );
						}
			}
			else //c is upper
			{
				for ( _j = 0; _j < N; ++_j )
					for ( _i = 0; _i < M; ++_i )
						if ( (doff_t)_j - (doff_t)_i >= 0 )
						{
							bli_sssxpbys( *(C + _i*rsc + _j*ldc),
							              *(beta_cast),
							              *(matCbuf + _i*rs_matC + _j*ldc_matC) );
						}
			}
		}

		return BLIS_SUCCESS;
	}
	else
		return BLIS_NONCONFORMAL_DIMENSIONS;
}

static err_t bli_dsyrk_small_atbn
     (
       obj_t*  alpha,
       obj_t*  a,
       obj_t*  b,
       obj_t*  beta,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl
     )
{
	int M = bli_obj_length( c ); // number of rows of Matrix C
	int N = bli_obj_width( c );  // number of columns of Matrix C
	int K = bli_obj_length( b ); // number of rows of Matrix B
	int lda = bli_obj_col_stride( a );      // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
	int ldb = bli_obj_col_stride( b );      // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
	int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
	int ldc = M; //bli_obj_col_stride( c ); // column stride of the static buffer for matrix C
	int row_idx = 0, col_idx = 0, k;
	int rs_matC = bli_obj_row_stride( c );
	int rsc = 1;
	double *A = a->buffer; // pointer to matrix A elements, stored in row major format
	double *B = b->buffer; // pointer to matrix B elements, stored in column major format
	double *C = D_C_pack;  // pointer to matrix C elements, stored in column major format
	double *matCbuf = c->buffer;

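	// Note: results are accumulated into the packed scratch buffer D_C_pack (a static buffer
	// presumably defined earlier in this file), addressed column-major with leading dimension
	// ldc = M; only the requested triangle is copied into the user's buffer matCbuf at the end.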
	double *tA = A, *tB = B, *tC = C;

	__m256d ymm4, ymm5, ymm6, ymm7;
	__m256d ymm8, ymm9, ymm10, ymm11;
	__m256d ymm12, ymm13, ymm14, ymm15;
	__m256d ymm0, ymm1, ymm2, ymm3;

	double result, scratch[8];
	double *alpha_cast, *beta_cast; // alpha, beta multiples
	alpha_cast = (alpha->buffer);
	beta_cast = (beta->buffer);

	// The non-copy version of the A^T SYRK gives better performance for the small M cases.
	// The threshold is controlled by BLIS_ATBN_M_THRES.
	if (M <= BLIS_ATBN_M_THRES)
	{
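		// When M exceeds the threshold the kernel returns BLIS_NONCONFORMAL_DIMENSIONS below,
		// which (presumably) lets the caller fall back to the conventional SYRK path.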
		for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
		{
			for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
			{
				tA = A + row_idx * lda;
				tB = B + col_idx * ldb;
				tC = C + col_idx * ldc + row_idx;
				// clear scratch registers.
				ymm4 = _mm256_setzero_pd();
				ymm5 = _mm256_setzero_pd();
				ymm6 = _mm256_setzero_pd();
				ymm7 = _mm256_setzero_pd();
				ymm8 = _mm256_setzero_pd();
				ymm9 = _mm256_setzero_pd();
				ymm10 = _mm256_setzero_pd();
				ymm11 = _mm256_setzero_pd();
				ymm12 = _mm256_setzero_pd();
				ymm13 = _mm256_setzero_pd();
				ymm14 = _mm256_setzero_pd();
				ymm15 = _mm256_setzero_pd();

				//The inner loop computes the 4x3 values of the matrix.
				//The computation pattern is:
				// ymm4  ymm5  ymm6
				// ymm7  ymm8  ymm9
				// ymm10 ymm11 ymm12
				// ymm13 ymm14 ymm15

				//The dot product is performed in the inner loop; four double elements fit
				//in a YMM register, hence the loop count is incremented by 4.
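				// Each accumulator above corresponds to one (row, column) entry of the 4x3 block
				// of C: ymm4..ymm6 belong to the block's row 0, ymm7..ymm9 to row 1, and so on.
				// Every FMA adds four partial products of the running dot product; the horizontal
				// reduction after the k loop collapses each register to a single scalar.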
				for (k = 0; (k + 3) < K; k += 4)
				{
					ymm0 = _mm256_loadu_pd(tB + 0);
					ymm1 = _mm256_loadu_pd(tB + ldb);
					ymm2 = _mm256_loadu_pd(tB + 2 * ldb);

					ymm3 = _mm256_loadu_pd(tA);
					ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
					ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
					ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);

					ymm3 = _mm256_loadu_pd(tA + lda);
					ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
					ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
					ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);

					ymm3 = _mm256_loadu_pd(tA + 2 * lda);
					ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
					ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
					ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);

					ymm3 = _mm256_loadu_pd(tA + 3 * lda);
					ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
					ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
					ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);

					tA += 4;
					tB += 4;
				}

				// if K is not a multiple of 4, padding is done before the load using a temporary array.
				if (k < K)
				{
					int iter;
					double data_feeder[4] = { 0.0 };

					for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
					ymm0 = _mm256_loadu_pd(data_feeder);
					for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
					ymm1 = _mm256_loadu_pd(data_feeder);
					for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
					ymm2 = _mm256_loadu_pd(data_feeder);

					for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
					ymm3 = _mm256_loadu_pd(data_feeder);
					ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
					ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
					ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);

					for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
					ymm3 = _mm256_loadu_pd(data_feeder);
					ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
					ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
					ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);

					for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
					ymm3 = _mm256_loadu_pd(data_feeder);
					ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
					ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
					ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);

					for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
					ymm3 = _mm256_loadu_pd(data_feeder);
					ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
					ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
					ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
				}

				//horizontal addition and storage of the data.
				//Results for the 4x3 block of C are stored here.
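				// For doubles a single _mm256_hadd_pd suffices: it sums adjacent pairs within each
				// 128-bit lane, so after the store the full dot product is scratch[0] + scratch[2],
				// which is then scaled by alpha.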
				ymm4 = _mm256_hadd_pd(ymm4, ymm4);
				_mm256_storeu_pd(scratch, ymm4);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[0] = result/* + tC[0] * (*beta_cast)*/;

				ymm7 = _mm256_hadd_pd(ymm7, ymm7);
				_mm256_storeu_pd(scratch, ymm7);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[1] = result/* + tC[1] * (*beta_cast)*/;

				ymm10 = _mm256_hadd_pd(ymm10, ymm10);
				_mm256_storeu_pd(scratch, ymm10);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[2] = result/* + tC[2] * (*beta_cast)*/;

				ymm13 = _mm256_hadd_pd(ymm13, ymm13);
				_mm256_storeu_pd(scratch, ymm13);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[3] = result/* + tC[3] * (*beta_cast)*/;

				tC += ldc;
				ymm5 = _mm256_hadd_pd(ymm5, ymm5);
				_mm256_storeu_pd(scratch, ymm5);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[0] = result/* + tC[0] * (*beta_cast)*/;

				ymm8 = _mm256_hadd_pd(ymm8, ymm8);
				_mm256_storeu_pd(scratch, ymm8);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[1] = result/* + tC[1] * (*beta_cast)*/;

				ymm11 = _mm256_hadd_pd(ymm11, ymm11);
				_mm256_storeu_pd(scratch, ymm11);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[2] = result/* + tC[2] * (*beta_cast)*/;

				ymm14 = _mm256_hadd_pd(ymm14, ymm14);
				_mm256_storeu_pd(scratch, ymm14);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[3] = result/* + tC[3] * (*beta_cast)*/;

				tC += ldc;
				ymm6 = _mm256_hadd_pd(ymm6, ymm6);
				_mm256_storeu_pd(scratch, ymm6);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[0] = result/* + tC[0] * (*beta_cast)*/;

				ymm9 = _mm256_hadd_pd(ymm9, ymm9);
				_mm256_storeu_pd(scratch, ymm9);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[1] = result/* + tC[1] * (*beta_cast)*/;

				ymm12 = _mm256_hadd_pd(ymm12, ymm12);
				_mm256_storeu_pd(scratch, ymm12);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[2] = result/* + tC[2] * (*beta_cast)*/;

				ymm15 = _mm256_hadd_pd(ymm15, ymm15);
				_mm256_storeu_pd(scratch, ymm15);
				result = scratch[0] + scratch[2];
				result *= (*alpha_cast);
				tC[3] = result/* + tC[3] * (*beta_cast)*/;
			}
		}

		int processed_col = col_idx;
		int processed_row = row_idx;

		//The edge case handling where N is not a multiple of 3
		if (processed_col < N)
		{
			for (col_idx = processed_col; col_idx < N; col_idx += 1)
			{
				for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
				{
					tA = A + row_idx * lda;
					tB = B + col_idx * ldb;
					tC = C + col_idx * ldc + row_idx;
					// clear scratch registers.
					ymm4 = _mm256_setzero_pd();
					ymm7 = _mm256_setzero_pd();
					ymm10 = _mm256_setzero_pd();
					ymm13 = _mm256_setzero_pd();

					//The inner loop computes the 4x1 values of the matrix.
					//The computation pattern is:
					// ymm4
					// ymm7
					// ymm10
					// ymm13

					for (k = 0; (k + 3) < K; k += 4)
					{
						ymm0 = _mm256_loadu_pd(tB + 0);

						ymm3 = _mm256_loadu_pd(tA);
						ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);

						ymm3 = _mm256_loadu_pd(tA + lda);
						ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);

						ymm3 = _mm256_loadu_pd(tA + 2 * lda);
						ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);

						ymm3 = _mm256_loadu_pd(tA + 3 * lda);
						ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);

						tA += 4;
						tB += 4;
					}

					// if K is not a multiple of 4, padding is done before the load using a temporary array.
					if (k < K)
					{
						int iter;
						double data_feeder[4] = { 0.0 };

						for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
						ymm0 = _mm256_loadu_pd(data_feeder);

						for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
						ymm3 = _mm256_loadu_pd(data_feeder);
						ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);

						for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
						ymm3 = _mm256_loadu_pd(data_feeder);
						ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);

						for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
						ymm3 = _mm256_loadu_pd(data_feeder);
						ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);

						for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
						ymm3 = _mm256_loadu_pd(data_feeder);
						ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
					}

					//horizontal addition and storage of the data.
					//Results for the 4x1 block of C are stored here.
					ymm4 = _mm256_hadd_pd(ymm4, ymm4);
					_mm256_storeu_pd(scratch, ymm4);
					result = scratch[0] + scratch[2];
					result *= (*alpha_cast);
					tC[0] = result/* + tC[0] * (*beta_cast)*/;

					ymm7 = _mm256_hadd_pd(ymm7, ymm7);
					_mm256_storeu_pd(scratch, ymm7);
					result = scratch[0] + scratch[2];
					result *= (*alpha_cast);
					tC[1] = result/* + tC[1] * (*beta_cast)*/;

					ymm10 = _mm256_hadd_pd(ymm10, ymm10);
					_mm256_storeu_pd(scratch, ymm10);
					result = scratch[0] + scratch[2];
					result *= (*alpha_cast);
					tC[2] = result/* + tC[2] * (*beta_cast)*/;

					ymm13 = _mm256_hadd_pd(ymm13, ymm13);
					_mm256_storeu_pd(scratch, ymm13);
					result = scratch[0] + scratch[2];
					result *= (*alpha_cast);
					tC[3] = result/* + tC[3] * (*beta_cast)*/;
				}
			}
			processed_row = row_idx;
		}

		// The edge case handling where M is not a multiple of 4
		if (processed_row < M)
		{
			for (row_idx = processed_row; row_idx < M; row_idx += 1)
			{
				for (col_idx = 0; col_idx < N; col_idx += 1)
				{
					tA = A + row_idx * lda;
					tB = B + col_idx * ldb;
					tC = C + col_idx * ldc + row_idx;
					// clear scratch registers.
					ymm4 = _mm256_setzero_pd();

					for (k = 0; (k + 3) < K; k += 4)
					{
						ymm0 = _mm256_loadu_pd(tB + 0);
						ymm3 = _mm256_loadu_pd(tA);
						ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);

						tA += 4;
						tB += 4;
					}

					// if K is not a multiple of 4, padding is done before the load using a temporary array.
					if (k < K)
					{
						int iter;
						double data_feeder[4] = { 0.0 };

						for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
						ymm0 = _mm256_loadu_pd(data_feeder);

						for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
						ymm3 = _mm256_loadu_pd(data_feeder);
						ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
					}

					//horizontal addition and storage of the data.
					ymm4 = _mm256_hadd_pd(ymm4, ymm4);
					_mm256_storeu_pd(scratch, ymm4);
					result = scratch[0] + scratch[2];
					result *= (*alpha_cast);
					tC[0] = result/* + tC[0] * (*beta_cast)*/;
				}
			}
		}

		//copy/compute syrk values back to C
		if ( bli_seq0( *beta_cast ) ) //when beta is 0, just copy result to C
		{
			dim_t _i, _j;
			if(bli_obj_is_lower(c)) //c is lower
			{
				for ( _j = 0; _j < N; ++_j )
					for ( _i = 0; _i < M; ++_i )
						if ( (doff_t)_j - (doff_t)_i <= 0 )
						{
							bli_ddcopys( *(C + _i*rsc + _j*ldc),
							             *(matCbuf + _i*rs_matC + _j*ldc_matC) );
						}
			}
			else //c is upper
			{
				for ( _j = 0; _j < N; ++_j )
					for ( _i = 0; _i < M; ++_i )
						if ( (doff_t)_j - (doff_t)_i >= 0 )
						{
							bli_ddcopys( *(C + _i*rsc + _j*ldc),
							             *(matCbuf + _i*rs_matC + _j*ldc_matC) );
						}
			}
		}
		else //when beta is non-zero, multiply and store result to C
		{
			dim_t _i, _j;
			if(bli_obj_is_lower(c)) //c is lower
			{
				for ( _j = 0; _j < N; ++_j )
					for ( _i = 0; _i < M; ++_i )
						if ( (doff_t)_j - (doff_t)_i <= 0 )
						{
							bli_dddxpbys( *(C + _i*rsc + _j*ldc),
							              *(beta_cast),
							              *(matCbuf + _i*rs_matC + _j*ldc_matC) );
						}
			}
			else //c is upper
			{
				for ( _j = 0; _j < N; ++_j )
					for ( _i = 0; _i < M; ++_i )
						if ( (doff_t)_j - (doff_t)_i >= 0 )
						{
							bli_dddxpbys( *(C + _i*rsc + _j*ldc),
							              *(beta_cast),
							              *(matCbuf + _i*rs_matC + _j*ldc_matC) );
						}
			}
		}

		return BLIS_SUCCESS;
	}
	else
		return BLIS_NONCONFORMAL_DIMENSIONS;
}

#endif