Files
blis/kernels/zen/3/bli_syrk_small.c
Field G. Van Zee 29b0e1ef4e Code review + tweaks to AMD's AOCL 2.0 PR (#349).
Details:
- NOTE: This is a merge commit of 'master' of git://github.com/amd/blis
  into 'amd-master' of flame/blis.
- Fixed a bug in the downstream value of BLIS_NUM_ARCHS, which was
  inadvertently not incremented when the Zen2 subconfiguration was
  added.
- In bli_gemm_front(), added a missing conditional constraint around the
  call to bli_gemm_small() that ensures that the computation precision
  of C matches the storage precision of C.
- In bli_syrk_front(), reorganized and relocated the notrans/trans logic
  that existed around the call to bli_syrk_small() into bli_syrk_small()
  to minimize the calling code footprint and also to bring that code
  into stylistic harmony with similar code in bli_gemm_front() and
  bli_trsm_front(). Also, replaced direct accessing of obj_t fields with
  proper accessor static functions (e.g. 'a->dim[0]' becomes
  'bli_obj_length( a )').
- Added #ifdef BLIS_ENABLE_SMALL_MATRIX guard around prototypes for
  bli_gemm_small(), bli_syrk_small(), and bli_trsm_small(). This is
  strictly speaking unnecessary, but it serves as a useful visual cue to
  those who may be reading the files.
- Removed cpp macro-protected small matrix debugging code from
  bli_trsm_front.c.
- Added a GCC_OT_9_1_0 variable to build/config.mk.in to facilitate gcc
  version check for availability of -march=znver2, and added appropriate
  support to configure script.
- Cleanups to compiler flags common to recent AMD microarchitectures in
  config/zen/amd_config.mk, including: removal of -march=znver1 et al.
  from CKVECFLAGS (since the -march flag is added within make_defs.mk);
  setting CRVECFLAGS similarly to CKVECFLAGS.
- Cleanups to config/zen/bli_cntx_init_zen.c.
- Cleanups, added comments to config/zen/make_defs.mk.
- Cleanups to config/zen2/make_defs.mk, including making use of newly-
  added GCC_OT_9_1_0 and existing GCC_OT_6_1_0 to choose the correct
  set of compiler flags based on the version of gcc being used.
- Reverted downstream changes to test/test_gemm.c.
- Various whitespace/comment changes.
2019-10-11 10:24:24 -05:00

4211 lines
166 KiB
C

/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name of The University of Texas at Austin nor the names
of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "immintrin.h"
#include "xmmintrin.h"
#include "blis.h"
#ifdef BLIS_ENABLE_SMALL_MATRIX
// Register-blocking dimensions for the single-precision small-matrix
// kernels: each inner kernel computes an MR x NR micro-tile of C.
#define MR 32
// Double-precision row blocking: half the single-precision MR.
#define D_MR (MR >> 1)
#define NR 3
#define BLIS_ENABLE_PREFETCH
// Capacity (in elements) of the static scratch buffers used to pack
// single-precision A and to stage C below the small-matrix threshold.
#define F_SCRATCH_DIM (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)
// 64-byte-aligned file-scope packing buffers for single precision.
// NOTE(review): these statics are shared mutable state, so this code
// path appears not to be thread-safe -- presumably that is why
// bli_syrk_small() rejects BLIS_ENABLE_MULTITHREADING builds; confirm.
static float A_pack[F_SCRATCH_DIM] __attribute__((aligned(64)));
static float C_pack[F_SCRATCH_DIM] __attribute__((aligned(64)));
// Double-precision thresholds are half their single-precision counterparts.
#define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 )
#define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2)
#define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2)
#define D_SCRATCH_DIM (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)
// 64-byte-aligned file-scope packing buffers for double precision.
static double D_A_pack[D_SCRATCH_DIM] __attribute__((aligned(64)));
static double D_C_pack[D_SCRATCH_DIM] __attribute__((aligned(64)));
#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called.
#define AT_MR 4 // The kernel dimension of the A transpose SYRK kernel.(AT_MR * NR).
// Forward declarations of the type-specific small-matrix SYRK kernels.
// bli_syrk_small() dispatches to one of these based on the datatype of C
// and the transposition of A (see the dispatch logic in bli_syrk_small).
// Single-precision kernel for the no-transpose-A case.
static err_t bli_ssyrk_small
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);
// Double-precision kernel for the no-transpose-A case.
static err_t bli_dsyrk_small
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);
// Single-precision kernel for A transposed, B not transposed ("atbn").
static err_t bli_ssyrk_small_atbn
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);
// Double-precision kernel for A transposed, B not transposed ("atbn").
static err_t bli_dsyrk_small_atbn
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);
/*
* The bli_syrk_small function will use the
* custom MRxNR kernels, to perform the computation.
* The custom kernels are used if the [M * N] < 240 * 240
*/
/*
 * bli_syrk_small()
 *
 * Front-end dispatcher for the small-matrix SYRK implementation. It first
 * decides whether the problem is eligible for the small code path; when it
 * is not, a non-success error code is returned so the caller can fall back
 * to the conventional (large) implementation.
 *
 * Eligibility, as encoded below:
 * - If A is transposed, dispatch proceeds unconditionally here (the AT
 *   kernels apply their own thresholds).
 * - Otherwise, A's dimensions must fall within the
 *   BLIS_SMALL_MATRIX_A_THRES_{M,N}_SYRK thresholds.
 * - Multithreaded builds, alpha == 0, and non-unit row strides (i.e.
 *   row-major storage) are rejected.
 *
 * Returns the result of the selected type-specific kernel, or
 * BLIS_FAILURE / BLIS_NOT_YET_IMPLEMENTED / BLIS_INVALID_ROW_STRIDE to
 * signal that the conventional code path should be used instead.
 */
err_t bli_syrk_small
     (
       obj_t*  alpha,
       obj_t*  a,
       obj_t*  b,
       obj_t*  beta,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl
     )
{
	// FGVZ: This code was originally in bli_syrk_front(). However, it really
	// fits more naturally here within the bli_syrk_small() function. This
	// becomes a bit more obvious now that the code is here, as it contains
	// cpp macros such as BLIS_SMALL_MATRIX_A_THRES_M_SYRK, which are specific
	// to this implementation.
	if ( bli_obj_has_trans( a ) )
	{
		// Continue with small implementation.
		;
	}
	else if ( ( bli_obj_length( a ) <= BLIS_SMALL_MATRIX_A_THRES_M_SYRK &&
	            bli_obj_width( a )  <  BLIS_SMALL_MATRIX_A_THRES_N_SYRK ) ||
	          ( bli_obj_length( a ) <  BLIS_SMALL_MATRIX_A_THRES_M_SYRK &&
	            bli_obj_width( a )  <= BLIS_SMALL_MATRIX_A_THRES_N_SYRK ) )
	{
		// Continue with small implementation.
		;
	}
	else
	{
		// Reject the problem and return to large code path.
		return BLIS_FAILURE;
	}

#ifdef BLIS_ENABLE_MULTITHREADING
	// The small-matrix kernels use shared static packing buffers and are
	// single-threaded; bail out when multithreading is enabled.
	return BLIS_NOT_YET_IMPLEMENTED;
#endif

	// If alpha is zero, the operation reduces to scaling C by beta, which
	// the small code does not implement; defer to the conventional path.
	if ( bli_obj_equals( alpha, &BLIS_ZERO ) )
	{
		return BLIS_NOT_YET_IMPLEMENTED;
	}

	// Only column-major storage (unit row stride) is supported.
	if ( ( bli_obj_row_stride( a ) != 1 ) ||
	     ( bli_obj_row_stride( b ) != 1 ) ||
	     ( bli_obj_row_stride( c ) != 1 ) )
	{
		return BLIS_INVALID_ROW_STRIDE;
	}

	// Query the storage datatype of C via the proper accessor rather than
	// reading the obj_t 'info' bitfield directly, consistent with the
	// other accessor usage in this function.
	num_t dt = bli_obj_dt( c );

	if ( bli_obj_has_trans( a ) )
	{
		// A transposed, B not transposed: use the "atbn" kernels.
		if ( bli_obj_has_notrans( b ) )
		{
			if ( dt == BLIS_FLOAT )
			{
				return bli_ssyrk_small_atbn( alpha, a, b, beta, c, cntx, cntl );
			}
			else if ( dt == BLIS_DOUBLE )
			{
				return bli_dsyrk_small_atbn( alpha, a, b, beta, c, cntx, cntl );
			}
		}

		// Other transposition/datatype combinations are not implemented.
		return BLIS_NOT_YET_IMPLEMENTED;
	}

	if ( dt == BLIS_DOUBLE )
	{
		return bli_dsyrk_small( alpha, a, b, beta, c, cntx, cntl );
	}

	if ( dt == BLIS_FLOAT )
	{
		return bli_ssyrk_small( alpha, a, b, beta, c, cntx, cntl );
	}

	// Unsupported datatype (e.g. complex); defer to the large code path.
	return BLIS_NOT_YET_IMPLEMENTED;
}
static err_t bli_ssyrk_small
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
)
{
int M = bli_obj_length( c ); // number of rows of Matrix C
int N = bli_obj_width( c ); // number of columns of Matrix C
int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
int L = M * N;
if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES))
|| ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
{
int lda = bli_obj_col_stride(a); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
int ldb = bli_obj_col_stride(b); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C
int row_idx, col_idx, k;
int rs_matC = bli_obj_row_stride( c );
int rsc = 1;
float *A = a->buffer; // pointer to elements of Matrix A
float *B = b->buffer; // pointer to elements of Matrix B
float *C = C_pack; // pointer to elements of Matrix C
float *matCbuf = c->buffer;
float *tA = A, *tB = B, *tC = C;//, *tA_pack;
float *tA_packed; // temprorary pointer to hold packed A memory pointer
int row_idx_packed; //packed A memory row index
int lda_packed; //lda of packed A
int col_idx_start; //starting index after A matrix is packed.
dim_t tb_inc_row = 1; // row stride of matrix B
dim_t tb_inc_col = ldb; // column stride of matrix B
__m256 ymm4, ymm5, ymm6, ymm7;
__m256 ymm8, ymm9, ymm10, ymm11;
__m256 ymm12, ymm13, ymm14, ymm15;
__m256 ymm0, ymm1, ymm2, ymm3;
int n_remainder; // If the N is non multiple of 3.(N%3)
int m_remainder; // If the M is non multiple of 32.(M%32)
float *alpha_cast, *beta_cast; // alpha, beta multiples
alpha_cast = (alpha->buffer);
beta_cast = (beta->buffer);
int required_packing_A = 1;
// when N is equal to 1 call GEMV instead of SYRK
if (N == 1)
{
bli_gemv
(
alpha,
a,
b,
beta,
c
);
return BLIS_SUCCESS;
}
//update the pointer math if matrix B needs to be transposed.
if (bli_obj_has_trans( b ))
{
tb_inc_col = 1; //switch row and column strides
tb_inc_row = ldb;
}
if ((N <= 3) || ((MR * K) > F_SCRATCH_DIM))
{
required_packing_A = 0;
}
/*
* The computation loop runs for MRxN columns of C matrix, thus
* accessing the MRxK A matrix data and KxNR B matrix data.
* The computation is organized as inner loops of dimension MRxNR.
*/
// Process MR rows of C matrix at a time.
for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR)
{
col_idx_start = 0;
tA_packed = A;
row_idx_packed = row_idx;
lda_packed = lda;
// This is the part of the pack and compute optimization.
// During the first column iteration, we store the accessed A matrix into
// contiguous static memory. This helps to keep te A matrix in Cache and
// aviods the TLB misses.
if (required_packing_A)
{
col_idx = 0;
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
tA_packed = A_pack;
#if 0//def BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm11 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
// This loop is processing MR x K
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
_mm256_storeu_ps(tA_packed, ymm3); // the packing of matrix A
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
_mm256_storeu_ps(tA_packed + 8, ymm3); // the packing of matrix A
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
_mm256_storeu_ps(tA_packed + 16, ymm3); // the packing of matrix A
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
ymm3 = _mm256_loadu_ps(tA + 24);
_mm256_storeu_ps(tA_packed + 24, ymm3); // the packing of matrix A
// ymm7 += ymm0 * ymm3;
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
// ymm11 += ymm1 * ymm3;
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
// ymm15 += ymm2 * ymm3;
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
tA += lda;
tA_packed += MR;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
ymm10 = _mm256_mul_ps(ymm10, ymm0);
ymm11 = _mm256_mul_ps(ymm11, ymm0);
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate col 1.
/*ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
_mm256_storeu_ps(tC + 16, ymm6);
_mm256_storeu_ps(tC + 24, ymm7);
// multiply C by beta and accumulate, col 2.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
_mm256_storeu_ps(tC + 24, ymm11);
// multiply C by beta and accumulate, col 3.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
_mm256_storeu_ps(tC + 24, ymm15);
// modify the pointer arithematic to use packed A matrix.
col_idx_start = NR;
tA_packed = A_pack;
row_idx_packed = 0;
lda_packed = MR;
}
// Process NR columns of C matrix at a time.
for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = tA_packed + row_idx_packed;
#if 0//def BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm11 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
// This loop is processing MR x K
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
ymm3 = _mm256_loadu_ps(tA + 24);
// ymm7 += ymm0 * ymm3;
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
// ymm11 += ymm1 * ymm3;
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
// ymm15 += ymm2 * ymm3;
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
tA += lda_packed;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
ymm10 = _mm256_mul_ps(ymm10, ymm0);
ymm11 = _mm256_mul_ps(ymm11, ymm0);
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate col 1.
/*ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
_mm256_storeu_ps(tC + 16, ymm6);
_mm256_storeu_ps(tC + 24, ymm7);
// multiply C by beta and accumulate, col 2.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
_mm256_storeu_ps(tC + 24, ymm11);
// multiply C by beta and accumulate, col 3.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
_mm256_storeu_ps(tC + 24, ymm15);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm11 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
ymm3 = _mm256_loadu_ps(tA + 24);
ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);
ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
ymm10 = _mm256_mul_ps(ymm10, ymm0);
ymm11 = _mm256_mul_ps(ymm11, ymm0);
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate, col 1.
/*ymm2 = _mm256_loadu_ps(tC + 0);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/
_mm256_storeu_ps(tC + 0, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
_mm256_storeu_ps(tC + 24, ymm11);
// multiply C by beta and accumulate, col 2.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
_mm256_storeu_ps(tC + 24, ymm15);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
ymm3 = _mm256_loadu_ps(tA + 24);
ymm15 = _mm256_fmadd_ps(ymm0, ymm3, ymm15);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_ps(tC + 0);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
_mm256_storeu_ps(tC + 0, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
_mm256_storeu_ps(tC + 24, ymm15);
}
}
m_remainder = M - row_idx;
if (m_remainder >= 24)
{
m_remainder -= 24;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
ymm10 = _mm256_mul_ps(ymm10, ymm0);
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);*/
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
_mm256_storeu_ps(tC + 16, ymm6);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);*/
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
ymm10 = _mm256_mul_ps(ymm10, ymm0);
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_ps(tC + 0);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);*/
_mm256_storeu_ps(tC + 0, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_ps(tC + 0);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/
_mm256_storeu_ps(tC + 0, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
}
row_idx += 24;
}
if (m_remainder >= 16)
{
m_remainder -= 16;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
ymm8 = _mm256_fmadd_ps(ymm2, ymm3, ymm8);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
_mm256_storeu_ps(tC, ymm6);
_mm256_storeu_ps(tC + 8, ymm7);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);*/
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
_mm256_storeu_ps(tC, ymm6);
_mm256_storeu_ps(tC + 8, ymm7);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
}
row_idx += 16;
}
if (m_remainder >= 8)
{
m_remainder -= 8;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/
_mm256_storeu_ps(tC, ymm4);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
_mm256_storeu_ps(tC, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);*/
_mm256_storeu_ps(tC, ymm6);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/
_mm256_storeu_ps(tC, ymm4);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_ps(tC);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
_mm256_storeu_ps(tC, ymm5);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
ymm4 = _mm256_mul_ps(ymm4, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/
_mm256_storeu_ps(tC, ymm4);
}
row_idx += 8;
}
// M is not a multiple of 32.
// The handling of edge case where the remainder
// dimension is less than 8. The padding takes place
// to handle this case.
if ((m_remainder) && (lda > 7))
{
float f_temp[8];
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm5 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
_mm256_storeu_ps(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
_mm256_storeu_ps(f_temp, ymm7);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);*/
_mm256_storeu_ps(f_temp, ymm9);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm5 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
ymm3 = _mm256_loadu_ps(tA);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
tA += lda;
}
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
_mm256_storeu_ps(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
_mm256_storeu_ps(f_temp, ymm7);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm5 = _mm256_setzero_ps();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
ymm3 = _mm256_loadu_ps(tA);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
tA += lda;
}
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm0 = _mm256_broadcast_ss(alpha_cast);
//ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate.
ymm5 = _mm256_mul_ps(ymm5, ymm0);
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
_mm256_storeu_ps(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
m_remainder = 0;
}
if (m_remainder)
{
float result;
for (; row_idx < M; row_idx += 1)
{
for (col_idx = 0; col_idx < N; col_idx += 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
result = 0;
for (k = 0; k < K; ++k)
{
result += (*tA) * (*tB);
tA += lda;
tB += tb_inc_row;
}
result *= (*alpha_cast);
(*tC) = /*(*tC) * (*beta_cast) + */result;
}
}
}
//copy/compute syrk values back to C using SIMD
if ( bli_seq0( *beta_cast ) )
{//just copy in case of beta = 0
dim_t _i, _j, k, _l;
if(bli_obj_is_lower(c)) // c is lower
{
//first column
_j = 0;
k = M >> 3;
_i = 0;
for ( _l = 0; _l < k; _l++ )
{
ymm0 = _mm256_loadu_ps((C + _i*rsc));
_mm256_storeu_ps((matCbuf + _i*rs_matC), ymm0);
_i += 8;
}
while (_i < M )
{
bli_sscopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
_i++;
}
_j++;
while ( _j < N ) //next column
{
//k = (_j + (8 - (_j & 7)));
_l = _j & 7;
k = (_l != 0) ? (_j + (8 - _l)) : _j;
k = (k <= M) ? k : M;
for ( _i = _j; _i < k; ++_i )
{
bli_sscopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
k = (M - _i) >> 3;
_l = 0;
while ( _l < k )
{
ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
_mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
_i += 8;
_l++;
}
while (_i < M )
{
bli_sscopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
_i++;
}
_j++;
}
}
else //c is upper
{
for ( _j = 0; _j < N; ++_j )
{
k = (_j + 1) >> 3;
_i = 0;
_l = 0;
while ( _l < k )
{
ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
_mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
_i += 8;
_l++;
}
while (_i <= _j )
{
bli_sscopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
++_i;
}
}
}
}
else
{//when beta is non-zero, fmadd and store the results
dim_t _i, _j, k, _l;
ymm1 = _mm256_broadcast_ss(beta_cast);
if(bli_obj_is_lower(c)) //c is lower
{
//first column
_j = 0;
k = M >> 3;
_i = 0;
for ( _l = 0; _l < k; _l++ )
{
ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC));
ymm0 = _mm256_loadu_ps((C + _i*rsc));
ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0);
_mm256_storeu_ps((matCbuf + _i*rs_matC), ymm0);
_i += 8;
}
while (_i < M )
{
bli_sssxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
_i++;
}
_j++;
while ( _j < N ) //next column
{
//k = (_j + (8 - (_j & 7)));
_l = _j & 7;
k = (_l != 0) ? (_j + (8 - _l)) : _j;
k = (k <= M) ? k : M;
for ( _i = _j; _i < k; ++_i )
{
bli_sssxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
k = (M - _i) >> 3;
_l = 0;
while ( _l < k )
{
ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC + _j*ldc_matC));
ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0);
_mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
_i += 8;
_l++;
}
while (_i < M )
{
bli_sssxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
_i++;
}
_j++;
}
}
else //c is upper
{
for ( _j = 0; _j < N; ++_j )
{
k = (_j + 1) >> 3;
_i = 0;
_l = 0;
while ( _l < k )
{
ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC + _j*ldc_matC));
ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0);
_mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
_i += 8;
_l++;
}
while (_i <= _j )
{
bli_sssxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
++_i;
}
}
}
}
return BLIS_SUCCESS;
}
else
return BLIS_NONCONFORMAL_DIMENSIONS;
};
static err_t bli_dsyrk_small
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
)
{
int M = bli_obj_length( c ); // number of rows of Matrix C
int N = bli_obj_width( c ); // number of columns of Matrix C
int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
int L = M * N;
// If alpha is zero, scale by beta and return.
if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES))
|| ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
{
int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C
int row_idx, col_idx, k;
int rs_matC = bli_obj_row_stride( c );
int rsc = 1;
double *A = a->buffer; // pointer to elements of Matrix A
double *B = b->buffer; // pointer to elements of Matrix B
double *C = D_C_pack; // pointer to elements of Matrix C
double *matCbuf = c->buffer;
double *tA = A, *tB = B, *tC = C;//, *tA_pack;
double *tA_packed; // temprorary pointer to hold packed A memory pointer
int row_idx_packed; //packed A memory row index
int lda_packed; //lda of packed A
int col_idx_start; //starting index after A matrix is packed.
dim_t tb_inc_row = 1; // row stride of matrix B
dim_t tb_inc_col = ldb; // column stride of matrix B
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm0, ymm1, ymm2, ymm3;
int n_remainder; // If the N is non multiple of 3.(N%3)
int m_remainder; // If the M is non multiple of 16.(M%16)
double *alpha_cast, *beta_cast; // alpha, beta multiples
alpha_cast = (alpha->buffer);
beta_cast = (beta->buffer);
int required_packing_A = 1;
// when N is equal to 1 call GEMV instead of SYRK
if (N == 1)
{
bli_gemv
(
alpha,
a,
b,
beta,
c
);
return BLIS_SUCCESS;
}
//update the pointer math if matrix B needs to be transposed.
if (bli_obj_has_trans( b ))
{
tb_inc_col = 1; //switch row and column strides
tb_inc_row = ldb;
}
if ((N <= 3) || ((D_MR * K) > D_SCRATCH_DIM))
{
required_packing_A = 0;
}
/*
* The computation loop runs for D_MRxN columns of C matrix, thus
* accessing the D_MRxK A matrix data and KxNR B matrix data.
* The computation is organized as inner loops of dimension D_MRxNR.
*/
// Process D_MR rows of C matrix at a time.
for (row_idx = 0; (row_idx + (D_MR - 1)) < M; row_idx += D_MR)
{
col_idx_start = 0;
tA_packed = A;
row_idx_packed = row_idx;
lda_packed = lda;
// This is the part of the pack and compute optimization.
// During the first column iteration, we store the accessed A matrix into
// contiguous static memory. This helps to keep the A matrix in cache and
// avoids TLB misses.
if (required_packing_A)
{
col_idx = 0;
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
tA_packed = D_A_pack;
#if 0//def BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
// This loop is processing D_MR x K
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
_mm256_storeu_pd(tA_packed, ymm3); // the packing of matrix A
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
_mm256_storeu_pd(tA_packed + 4, ymm3); // the packing of matrix A
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
_mm256_storeu_pd(tA_packed + 8, ymm3); // the packing of matrix A
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
ymm3 = _mm256_loadu_pd(tA + 12);
_mm256_storeu_pd(tA_packed + 12, ymm3); // the packing of matrix A
// ymm7 += ymm0 * ymm3;
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
// ymm11 += ymm1 * ymm3;
ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
// ymm15 += ymm2 * ymm3;
ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
tA += lda;
tA_packed += D_MR;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm11 = _mm256_mul_pd(ymm11, ymm0);
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
// multiply C by beta and accumulate col 1.
/*ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
_mm256_storeu_pd(tC + 8, ymm6);
_mm256_storeu_pd(tC + 12, ymm7);
// multiply C by beta and accumulate, col 2.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/
_mm256_storeu_pd(tC, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
_mm256_storeu_pd(tC + 8, ymm10);
_mm256_storeu_pd(tC + 12, ymm11);
// multiply C by beta and accumulate, col 3.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
_mm256_storeu_pd(tC, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
_mm256_storeu_pd(tC + 12, ymm15);
// modify the pointer arithmetic to use the packed A matrix.
col_idx_start = NR;
tA_packed = D_A_pack;
row_idx_packed = 0;
lda_packed = D_MR;
}
// Process NR columns of C matrix at a time.
for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = tA_packed + row_idx_packed;
#if 0//def BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
// This loop is processing D_MR x K
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
ymm3 = _mm256_loadu_pd(tA + 12);
// ymm7 += ymm0 * ymm3;
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
// ymm11 += ymm1 * ymm3;
ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
// ymm15 += ymm2 * ymm3;
ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
tA += lda_packed;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm11 = _mm256_mul_pd(ymm11, ymm0);
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
// multiply C by beta and accumulate col 1.
/*ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
_mm256_storeu_pd(tC + 8, ymm6);
_mm256_storeu_pd(tC + 12, ymm7);
// multiply C by beta and accumulate, col 2.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/
_mm256_storeu_pd(tC, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
_mm256_storeu_pd(tC + 8, ymm10);
_mm256_storeu_pd(tC + 12, ymm11);
// multiply C by beta and accumulate, col 3.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
_mm256_storeu_pd(tC, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
_mm256_storeu_pd(tC + 12, ymm15);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
ymm3 = _mm256_loadu_pd(tA + 12);
ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11);
ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm11 = _mm256_mul_pd(ymm11, ymm0);
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
// multiply C by beta and accumulate, col 1.
/*ymm2 = _mm256_loadu_pd(tC + 0);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/
_mm256_storeu_pd(tC + 0, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
_mm256_storeu_pd(tC + 8, ymm10);
_mm256_storeu_pd(tC + 12, ymm11);
// multiply C by beta and accumulate, col 2.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
_mm256_storeu_pd(tC, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
_mm256_storeu_pd(tC + 12, ymm15);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
ymm3 = _mm256_loadu_pd(tA + 12);
ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_pd(tC + 0);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
_mm256_storeu_pd(tC + 0, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
_mm256_storeu_pd(tC + 12, ymm15);
}
}
m_remainder = M - row_idx;
if (m_remainder >= 12)
{
m_remainder -= 12;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);*/
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
_mm256_storeu_pd(tC + 8, ymm6);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);*/
_mm256_storeu_pd(tC, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
_mm256_storeu_pd(tC + 8, ymm10);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/
_mm256_storeu_pd(tC, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_pd(tC + 0);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);*/
_mm256_storeu_pd(tC + 0, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
_mm256_storeu_pd(tC + 8, ymm10);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/
_mm256_storeu_pd(tC, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_pd(tC + 0);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/
_mm256_storeu_pd(tC + 0, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
}
row_idx += 12;
}
if (m_remainder >= 8)
{
m_remainder -= 8;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
_mm256_storeu_pd(tC, ymm6);
_mm256_storeu_pd(tC + 4, ymm7);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);*/
_mm256_storeu_pd(tC, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
_mm256_storeu_pd(tC, ymm6);
_mm256_storeu_pd(tC + 4, ymm7);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
}
row_idx += 8;
}
if (m_remainder >= 4)
{
m_remainder -= 4;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/
_mm256_storeu_pd(tC, ymm4);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
_mm256_storeu_pd(tC, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);*/
_mm256_storeu_pd(tC, ymm6);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/
_mm256_storeu_pd(tC, ymm4);
// multiply C by beta and accumulate.
tC += ldc;
/*ymm2 = _mm256_loadu_pd(tC);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
_mm256_storeu_pd(tC, ymm5);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
ymm4 = _mm256_mul_pd(ymm4, ymm0);
// multiply C by beta and accumulate.
/*ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/
_mm256_storeu_pd(tC, ymm4);
}
row_idx += 4;
}
// M is not a multiple of 32.
// The handling of edge case where the remainder
// dimension is less than 8. The padding takes place
// to handle this case.
if ((m_remainder) && (lda > 3))
{
double f_temp[8];
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm5 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
_mm256_storeu_pd(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
_mm256_storeu_pd(f_temp, ymm7);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);*/
_mm256_storeu_pd(f_temp, ymm9);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm5 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
ymm3 = _mm256_loadu_pd(tA);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
tA += lda;
}
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
_mm256_storeu_pd(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
_mm256_storeu_pd(f_temp, ymm7);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm5 = _mm256_setzero_pd();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
ymm3 = _mm256_loadu_pd(tA);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
tA += lda;
}
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm0 = _mm256_broadcast_sd(alpha_cast);
//ymm1 = _mm256_broadcast_sd(beta_cast);
// multiply C by beta and accumulate.
ymm5 = _mm256_mul_pd(ymm5, ymm0);
/*for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
_mm256_storeu_pd(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
m_remainder = 0;
}
if (m_remainder)
{
double result;
for (; row_idx < M; row_idx += 1)
{
for (col_idx = 0; col_idx < N; col_idx += 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
result = 0;
for (k = 0; k < K; ++k)
{
result += (*tA) * (*tB);
tA += lda;
tB += tb_inc_row;
}
result *= (*alpha_cast);
(*tC) = /*(*tC) * (*beta_cast) + */result;
}
}
}
//copy/compute syrk values back to C using SIMD
if ( bli_seq0( *beta_cast ) )
{//just copy for beta = 0
dim_t _i, _j, k, _l;
if(bli_obj_is_lower(c)) //c is lower
{
//first column
_j = 0;
k = M >> 2;
_i = 0;
for ( _l = 0; _l < k; _l++ )
{
ymm0 = _mm256_loadu_pd((C + _i*rsc));
_mm256_storeu_pd((matCbuf + _i*rs_matC), ymm0);
_i += 4;
}
while (_i < M )
{
bli_ddcopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
_i++;
}
_j++;
while ( _j < N ) //next column
{
//k = (_j + (4 - (_j & 3)));
_l = _j & 3;
k = (_l != 0) ? (_j + (4 - _l)) : _j;
k = (k <= M) ? k : M;
for ( _i = _j; _i < k; ++_i )
{
bli_ddcopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
k = (M - _i) >> 2;
_l = 0;
while ( _l < k )
{
ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
_mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
_i += 4;
_l++;
}
while (_i < M )
{
bli_ddcopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
_i++;
}
_j++;
}
}
else //c is upper
{
for ( _j = 0; _j < N; ++_j )
{
k = (_j + 1) >> 2;
_i = 0;
_l = 0;
while ( _l < k )
{
ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
_mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
_i += 4;
_l++;
}
while (_i <= _j )
{
bli_ddcopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
++_i;
}
}
}
}
else
{//when beta is non-zero, fmadd and store the results
dim_t _i, _j, k, _l;
ymm1 = _mm256_broadcast_sd(beta_cast);
if(bli_obj_is_lower(c)) //c is lower
{
//first column
_j = 0;
k = M >> 2;
_i = 0;
for ( _l = 0; _l < k; _l++ )
{
ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC));
ymm0 = _mm256_loadu_pd((C + _i*rsc));
ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0);
_mm256_storeu_pd((matCbuf + _i*rs_matC), ymm0);
_i += 4;
}
while (_i < M )
{
bli_dddxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
_i++;
}
_j++;
while ( _j < N ) //next column
{
//k = (_j + (4 - (_j & 3)));
_l = _j & 3;
k = (_l != 0) ? (_j + (4 - _l)) : _j;
k = (k <= M) ? k : M;
for ( _i = _j; _i < k; ++_i )
{
bli_dddxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
k = (M - _i) >> 2;
_l = 0;
while ( _l < k )
{
ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC + _j*ldc_matC));
ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0);
_mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
_i += 4;
_l++;
}
while (_i < M )
{
bli_dddxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
_i++;
}
_j++;
}
}
else //c is upper
{
for ( _j = 0; _j < N; ++_j )
{
k = (_j + 1) >> 2;
_i = 0;
_l = 0;
while ( _l < k )
{
ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC + _j*ldc_matC));
ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0);
_mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
_i += 4;
_l++;
}
while (_i <= _j )
{
bli_dddxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
++_i;
}
}
}
}
return BLIS_SUCCESS;
}
else
return BLIS_NONCONFORMAL_DIMENSIONS;
};
static err_t bli_ssyrk_small_atbn
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
)
{
int M = bli_obj_length(c); // number of rows of Matrix C
int N = bli_obj_width(c); // number of columns of Matrix C
int K = bli_obj_length(b); // number of rows of Matrix B
int lda = bli_obj_col_stride(a); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
int ldb = bli_obj_col_stride(b); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C
int row_idx = 0, col_idx = 0, k;
int rs_matC = bli_obj_row_stride( c );
int rsc = 1;
float *A = a->buffer; // pointer to matrix A elements, stored in row major format
float *B = b->buffer; // pointer to matrix B elements, stored in column major format
float *C = C_pack; // pointer to matrix C elements, stored in column major format
float *matCbuf = c->buffer;
float *tA = A, *tB = B, *tC = C;
__m256 ymm4, ymm5, ymm6, ymm7;
__m256 ymm8, ymm9, ymm10, ymm11;
__m256 ymm12, ymm13, ymm14, ymm15;
__m256 ymm0, ymm1, ymm2, ymm3;
float result, scratch[8];
float *alpha_cast, *beta_cast; // alpha, beta multiples
alpha_cast = (alpha->buffer);
beta_cast = (beta->buffer);
// The non-copy version of the A^T SYRK gives better performance for the small M cases.
// The threshold is controlled by BLIS_ATBN_M_THRES
if (M <= BLIS_ATBN_M_THRES)
{
for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
{
for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm11 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
//The inner loop computes the 4x3 values of the matrix.
//The computation pattern is:
// ymm4 ymm5 ymm6
// ymm7 ymm8 ymm9
// ymm10 ymm11 ymm12
// ymm13 ymm14 ymm15
//The Dot operation is performed in the inner loop, 8 float elements fit
//in the YMM register hence loop count incremented by 8
for (k = 0; (k + 7) < K; k += 8)
{
ymm0 = _mm256_loadu_ps(tB + 0);
ymm1 = _mm256_loadu_ps(tB + ldb);
ymm2 = _mm256_loadu_ps(tB + 2 * ldb);
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
ymm3 = _mm256_loadu_ps(tA + lda);
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
ymm3 = _mm256_loadu_ps(tA + 2 * lda);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 3 * lda);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
tA += 8;
tB += 8;
}
// if K is not a multiple of 8, padding is done before load using a temporary array.
if (k < K)
{
int iter;
float data_feeder[8] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_ps(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
ymm1 = _mm256_loadu_ps(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
ymm2 = _mm256_loadu_ps(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
}
//horizontal addition and storage of the data.
//Results for 4x3 blocks of C is stored here
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
_mm256_storeu_ps(scratch, ymm4);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[0] = result/* + tC[0] * (*beta_cast)*/;
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
_mm256_storeu_ps(scratch, ymm7);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[1] = result/* + tC[1] * (*beta_cast)*/;
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
_mm256_storeu_ps(scratch, ymm10);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[2] = result/* + tC[2] * (*beta_cast)*/;
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
_mm256_storeu_ps(scratch, ymm13);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[3] = result/* + tC[3] * (*beta_cast)*/;
tC += ldc;
ymm5 = _mm256_hadd_ps(ymm5, ymm5);
ymm5 = _mm256_hadd_ps(ymm5, ymm5);
_mm256_storeu_ps(scratch, ymm5);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[0] = result/* + tC[0] * (*beta_cast)*/;
ymm8 = _mm256_hadd_ps(ymm8, ymm8);
ymm8 = _mm256_hadd_ps(ymm8, ymm8);
_mm256_storeu_ps(scratch, ymm8);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[1] = result/* + tC[1] * (*beta_cast)*/;
ymm11 = _mm256_hadd_ps(ymm11, ymm11);
ymm11 = _mm256_hadd_ps(ymm11, ymm11);
_mm256_storeu_ps(scratch, ymm11);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[2] = result/* + tC[2] * (*beta_cast)*/;
ymm14 = _mm256_hadd_ps(ymm14, ymm14);
ymm14 = _mm256_hadd_ps(ymm14, ymm14);
_mm256_storeu_ps(scratch, ymm14);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[3] = result/* + tC[3] * (*beta_cast)*/;
tC += ldc;
ymm6 = _mm256_hadd_ps(ymm6, ymm6);
ymm6 = _mm256_hadd_ps(ymm6, ymm6);
_mm256_storeu_ps(scratch, ymm6);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[0] = result/* + tC[0] * (*beta_cast)*/;
ymm9 = _mm256_hadd_ps(ymm9, ymm9);
ymm9 = _mm256_hadd_ps(ymm9, ymm9);
_mm256_storeu_ps(scratch, ymm9);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[1] = result/* + tC[1] * (*beta_cast)*/;
ymm12 = _mm256_hadd_ps(ymm12, ymm12);
ymm12 = _mm256_hadd_ps(ymm12, ymm12);
_mm256_storeu_ps(scratch, ymm12);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[2] = result/* + tC[2] * (*beta_cast)*/;
ymm15 = _mm256_hadd_ps(ymm15, ymm15);
ymm15 = _mm256_hadd_ps(ymm15, ymm15);
_mm256_storeu_ps(scratch, ymm15);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[3] = result/* + tC[3] * (*beta_cast)*/;
}
}
int processed_col = col_idx;
int processed_row = row_idx;
//The edge case handling where N is not a multiple of 3
if (processed_col < N)
{
for (col_idx = processed_col; col_idx < N; col_idx += 1)
{
for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
//The inner loop computes the 4x1 values of the matrix.
//The computation pattern is:
// ymm4
// ymm7
// ymm10
// ymm13
for (k = 0; (k + 7) < K; k += 8)
{
ymm0 = _mm256_loadu_ps(tB + 0);
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm3 = _mm256_loadu_ps(tA + lda);
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
ymm3 = _mm256_loadu_ps(tA + 2 * lda);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm3 = _mm256_loadu_ps(tA + 3 * lda);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
tA += 8;
tB += 8;
}
// if K is not a multiple of 8, padding is done before load using temproary array.
if (k < K)
{
int iter;
float data_feeder[8] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_ps(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
}
//horizontal addition and storage of the data.
//Results for 4x1 blocks of C is stored here
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
_mm256_storeu_ps(scratch, ymm4);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[0] = result/* + tC[0] * (*beta_cast)*/;
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
_mm256_storeu_ps(scratch, ymm7);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[1] = result/* + tC[1] * (*beta_cast)*/;
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
_mm256_storeu_ps(scratch, ymm10);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[2] = result/* + tC[2] * (*beta_cast)*/;
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
_mm256_storeu_ps(scratch, ymm13);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[3] = result/* + tC[3] * (*beta_cast)*/;
}
}
processed_row = row_idx;
}
//The edge case handling where M is not a multiple of 4
if (processed_row < M)
{
for (row_idx = processed_row; row_idx < M; row_idx += 1)
{
for (col_idx = 0; col_idx < N; col_idx += 1)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
for (k = 0; (k + 7) < K; k += 8)
{
ymm0 = _mm256_loadu_ps(tB + 0);
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
tA += 8;
tB += 8;
}
// if K is not a multiple of 8, padding is done before load using temproary array.
if (k < K)
{
int iter;
float data_feeder[8] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_ps(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
}
//horizontal addition and storage of the data.
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
_mm256_storeu_ps(scratch, ymm4);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[0] = result/* + tC[0] * (*beta_cast)*/;
}
}
}
//copy/compute sryk values back to C
if ( bli_seq0( *beta_cast ) ) //when beta is 0, just copy result to C
{
dim_t _i, _j;
if(bli_obj_is_lower(c)) //c is lower
{
for ( _j = 0; _j < N; ++_j )
for ( _i = 0; _i < M; ++_i )
if ( (doff_t)_j - (doff_t)_i <= 0 )
{
bli_sscopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
}
else //c is upper
{
for ( _j = 0; _j < N; ++_j )
for ( _i = 0; _i < M; ++_i )
if ( (doff_t)_j - (doff_t)_i >= 0 )
{
bli_sscopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
}
}
else //when beta is non-zero, multiply and store result to C
{
dim_t _i, _j;
if(bli_obj_is_lower(c)) //c is lower
{
for ( _j = 0; _j < N; ++_j )
for ( _i = 0; _i < M; ++_i )
if ( (doff_t)_j - (doff_t)_i <= 0 )
{
bli_sssxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
}
else //c is upper
{
for ( _j = 0; _j < N; ++_j )
for ( _i = 0; _i < M; ++_i )
if ( (doff_t)_j - (doff_t)_i >= 0 )
{
bli_sssxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
}
}
return BLIS_SUCCESS;
}
else
return BLIS_NONCONFORMAL_DIMENSIONS;
}
// bli_dsyrk_small_atbn():  double-precision small-matrix SYRK kernel for the
// transposed-A ("A^T * B") case, computing C := beta*C + alpha*A^T*B with
// AVX2 FMA intrinsics and WITHOUT packing A (a "non-copy" kernel).
//
// The result is first accumulated at full density into a column-major scratch
// buffer (D_C_pack, declared earlier in this file), and only afterward is the
// lower or upper triangle (per the uplo of c) scaled by beta and written back
// to the user's C buffer.  This keeps the hot compute loops branch-free with
// respect to the triangular structure.
//
// Returns BLIS_SUCCESS on completion, or BLIS_NONCONFORMAL_DIMENSIONS when
// M exceeds BLIS_ATBN_M_THRES (the caller should then fall back to the
// conventional code path).
//
// NOTE(review): this routine assumes alpha, beta, a, b, c all carry double-
// precision real data; the dispatcher is expected to guarantee this.  cntx
// and cntl are unused here.
static err_t bli_dsyrk_small_atbn
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
)
{
int M = bli_obj_length( c ); // number of rows of Matrix C
int N = bli_obj_width( c ); // number of columns of Matrix C
int K = bli_obj_length( b ); // number of rows of Matrix B (the inner/dot-product dimension)
int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C (scratch is packed densely, M x N)
int row_idx = 0, col_idx = 0, k;
int rs_matC = bli_obj_row_stride( c );
int rsc = 1; // row stride within the scratch buffer (always unit: packed column-major)
double *A = a->buffer; // pointer to matrix A elements, stored in row major format
double *B = b->buffer; // pointer to matrix B elements, stored in column major format
double *C = D_C_pack; // pointer to matrix C elements, stored in column major format (static scratch buffer)
double *matCbuf = c->buffer; // the user's actual C buffer; written only in the final triangular copy-back
double *tA = A, *tB = B, *tC = C; // roving pointers into the current tile of A, B, and scratch C
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm0, ymm1, ymm2, ymm3;
double result, scratch[8]; // scratch[] receives a YMM store so the horizontal sum can finish in scalar code
double *alpha_cast, *beta_cast; // alpha, beta multiples
alpha_cast = (alpha->buffer);
beta_cast = (beta->buffer);
// The non-copy version of the A^T SYRK gives better performance for the small M cases.
// The threshold is controlled by BLIS_ATBN_M_THRES
if (M <= BLIS_ATBN_M_THRES)
{
// Main tiling: NR (=3) columns of C per outer iteration, AT_MR (=4) rows per inner iteration.
for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
{
for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
{
// Since A holds OP(A)=A^T row-major, row row_idx of C corresponds to column
// row_idx of the stored A, hence the lda stride here.
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
//The inner loop computes the 4x3 values of the matrix.
//The computation pattern is:
// ymm4 ymm5 ymm6
// ymm7 ymm8 ymm9
// ymm10 ymm11 ymm12
// ymm13 ymm14 ymm15
//The Dot operation is performed in the inner loop, 4 double elements fit
//in the YMM register hence loop count incremented by 4
for (k = 0; (k + 3) < K; k += 4)
{
// Load 4 k-elements from each of the 3 current B columns...
ymm0 = _mm256_loadu_pd(tB + 0);
ymm1 = _mm256_loadu_pd(tB + ldb);
ymm2 = _mm256_loadu_pd(tB + 2 * ldb);
// ...and from each of the 4 current A rows, accumulating 12 partial dot products.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
ymm3 = _mm256_loadu_pd(tA + lda);
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
ymm3 = _mm256_loadu_pd(tA + 2 * lda);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 3 * lda);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
tA += 4;
tB += 4;
}
// if K is not a multiple of 4, padding is done before load using temporary array.
// The zero padding contributes nothing to the FMAs, so the results stay exact.
if (k < K)
{
int iter;
double data_feeder[4] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_pd(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
ymm1 = _mm256_loadu_pd(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
ymm2 = _mm256_loadu_pd(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
}
//horizontal addition and storage of the data.
//Results for 4x3 blocks of C is stored here
// Each accumulator holds 4 partial sums; hadd_pd pairs lanes {0,1} and {2,3},
// then scratch[0] + scratch[2] finishes the reduction to a single dot product.
// Only alpha is applied here; beta*C is applied during the triangular
// copy-back below (the beta term in comments is intentionally disabled).
ymm4 = _mm256_hadd_pd(ymm4, ymm4);
_mm256_storeu_pd(scratch, ymm4);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[0] = result/* + tC[0] * (*beta_cast)*/;
ymm7 = _mm256_hadd_pd(ymm7, ymm7);
_mm256_storeu_pd(scratch, ymm7);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[1] = result/* + tC[1] * (*beta_cast)*/;
ymm10 = _mm256_hadd_pd(ymm10, ymm10);
_mm256_storeu_pd(scratch, ymm10);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[2] = result/* + tC[2] * (*beta_cast)*/;
ymm13 = _mm256_hadd_pd(ymm13, ymm13);
_mm256_storeu_pd(scratch, ymm13);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[3] = result/* + tC[3] * (*beta_cast)*/;
// Advance to the second of the three C columns in this tile.
tC += ldc;
ymm5 = _mm256_hadd_pd(ymm5, ymm5);
_mm256_storeu_pd(scratch, ymm5);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[0] = result/* + tC[0] * (*beta_cast)*/;
ymm8 = _mm256_hadd_pd(ymm8, ymm8);
_mm256_storeu_pd(scratch, ymm8);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[1] = result/* + tC[1] * (*beta_cast)*/;
ymm11 = _mm256_hadd_pd(ymm11, ymm11);
_mm256_storeu_pd(scratch, ymm11);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[2] = result/* + tC[2] * (*beta_cast)*/;
ymm14 = _mm256_hadd_pd(ymm14, ymm14);
_mm256_storeu_pd(scratch, ymm14);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[3] = result/* + tC[3] * (*beta_cast)*/;
// Advance to the third C column of the tile.
tC += ldc;
ymm6 = _mm256_hadd_pd(ymm6, ymm6);
_mm256_storeu_pd(scratch, ymm6);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[0] = result/* + tC[0] * (*beta_cast)*/;
ymm9 = _mm256_hadd_pd(ymm9, ymm9);
_mm256_storeu_pd(scratch, ymm9);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[1] = result/* + tC[1] * (*beta_cast)*/;
ymm12 = _mm256_hadd_pd(ymm12, ymm12);
_mm256_storeu_pd(scratch, ymm12);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[2] = result/* + tC[2] * (*beta_cast)*/;
ymm15 = _mm256_hadd_pd(ymm15, ymm15);
_mm256_storeu_pd(scratch, ymm15);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[3] = result/* + tC[3] * (*beta_cast)*/;
}
}
// Record how far the full 4x3 tiling got; the remainders are handled below.
int processed_col = col_idx;
int processed_row = row_idx;
//The edge case handling where N is not a multiple of NR (=3):
//remaining columns are processed one at a time, still AT_MR rows per step.
if (processed_col < N)
{
for (col_idx = processed_col; col_idx < N; col_idx += 1)
{
for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
//The inner loop computes the 4x1 values of the matrix.
//The computation pattern is:
// ymm4
// ymm7
// ymm10
// ymm13
for (k = 0; (k + 3) < K; k += 4)
{
ymm0 = _mm256_loadu_pd(tB + 0);
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm3 = _mm256_loadu_pd(tA + lda);
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
ymm3 = _mm256_loadu_pd(tA + 2 * lda);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
ymm3 = _mm256_loadu_pd(tA + 3 * lda);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
tA += 4;
tB += 4;
}
// if K is not a multiple of 4, padding is done before load using temporary array.
if (k < K)
{
int iter;
double data_feeder[4] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_pd(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
}
//horizontal addition and storage of the data.
//Results for 4x1 blocks of C is stored here
ymm4 = _mm256_hadd_pd(ymm4, ymm4);
_mm256_storeu_pd(scratch, ymm4);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[0] = result/* + tC[0] * (*beta_cast)*/;
ymm7 = _mm256_hadd_pd(ymm7, ymm7);
_mm256_storeu_pd(scratch, ymm7);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[1] = result/* + tC[1] * (*beta_cast)*/;
ymm10 = _mm256_hadd_pd(ymm10, ymm10);
_mm256_storeu_pd(scratch, ymm10);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[2] = result/* + tC[2] * (*beta_cast)*/;
ymm13 = _mm256_hadd_pd(ymm13, ymm13);
_mm256_storeu_pd(scratch, ymm13);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[3] = result/* + tC[3] * (*beta_cast)*/;
}
}
processed_row = row_idx;
}
// The edge case handling where M is not a multiple of AT_MR (=4):
// the remaining rows are processed one C element at a time, across ALL columns.
if (processed_row < M)
{
for (row_idx = processed_row; row_idx < M; row_idx += 1)
{
for (col_idx = 0; col_idx < N; col_idx += 1)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
for (k = 0; (k + 3) < K; k += 4)
{
ymm0 = _mm256_loadu_pd(tB + 0);
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
tA += 4;
tB += 4;
}
// if K is not a multiple of 4, padding is done before load using temporary array.
if (k < K)
{
int iter;
double data_feeder[4] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_pd(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
}
//horizontal addition and storage of the data.
ymm4 = _mm256_hadd_pd(ymm4, ymm4);
_mm256_storeu_pd(scratch, ymm4);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[0] = result/* + tC[0] * (*beta_cast)*/;
}
}
}
//copy/compute syrk values back to C: only the triangle selected by the
//uplo of c is touched, so the other triangle of the user's buffer is preserved.
if ( bli_seq0( *beta_cast ) ) //when beta is 0, just copy result to C
{
dim_t _i, _j;
if(bli_obj_is_lower(c)) //c is lower: keep elements with j <= i
{
for ( _j = 0; _j < N; ++_j )
for ( _i = 0; _i < M; ++_i )
if ( (doff_t)_j - (doff_t)_i <= 0 )
{
bli_ddcopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
}
else //c is upper: keep elements with j >= i
{
for ( _j = 0; _j < N; ++_j )
for ( _i = 0; _i < M; ++_i )
if ( (doff_t)_j - (doff_t)_i >= 0 )
{
bli_ddcopys( *(C + _i*rsc + _j*ldc),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
}
}
else //when beta is non-zero, multiply and store result to C (y := x + beta*y)
{
dim_t _i, _j;
if(bli_obj_is_lower(c)) //c is lower
{
for ( _j = 0; _j < N; ++_j )
for ( _i = 0; _i < M; ++_i )
if ( (doff_t)_j - (doff_t)_i <= 0 )
{
bli_dddxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
}
else //c is upper
{
for ( _j = 0; _j < N; ++_j )
for ( _i = 0; _i < M; ++_i )
if ( (doff_t)_j - (doff_t)_i >= 0 )
{
bli_dddxpbys( *(C + _i*rsc + _j*ldc),
*(beta_cast),
*(matCbuf + _i*rs_matC + _j*ldc_matC) );
}
}
}
return BLIS_SUCCESS;
}
else
return BLIS_NONCONFORMAL_DIMENSIONS;
}
#endif