Files
blis/kernels/zen/3/bli_gemm_small.c
2019-05-27 16:24:43 +05:30

3744 lines
149 KiB
C

/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "immintrin.h"
#include "xmmintrin.h"
#include "blis.h"
#ifdef BLIS_ENABLE_SMALL_MATRIX
// Everything that follows is compiled only when small-matrix support is enabled.

// Number of rows of C processed per iteration of the main row loop of the
// single-precision kernel (four 8-float AVX loads of A per k step).
#define MR 32
// Half of MR.
// NOTE(review): presumably the row-block size of the double-precision kernel,
// which is not visible in this part of the file -- confirm.
#define D_MR (MR >> 1)
// Number of columns of C processed per iteration of the column loop.
#define NR 3
// When defined, the kernels issue _mm_prefetch hints for the C tile before
// computing it.
#define BLIS_ENABLE_PREFETCH
// Capacity, in floats, of the single-precision packing scratch buffer.
// bli_sgemm_small disables packing when MR * K exceeds this value.
#define F_SCRATCH_DIM (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)
// File-scope, 64-byte aligned scratch buffer. bli_sgemm_small copies the
// current MR x K panel of A into it while computing the first NR columns, then
// reads A from here (leading dimension MR) for the remaining full column blocks.
// NOTE(review): this storage is shared by every call; concurrent callers would
// overwrite each other's packed panel -- confirm callers are serialized.
static float A_pack[F_SCRATCH_DIM] __attribute__((aligned(64)));
// Double-precision thresholds: each is half of its single-precision counterpart.
// NOTE(review): presumably consumed by bli_dgemm_small, which is not visible
// in this part of the file -- confirm.
#define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 )
#define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2)
#define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2)
// Capacity, in doubles, of the double-precision packing scratch buffer.
#define D_SCRATCH_DIM (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)
// Double-precision counterpart of A_pack (same alignment, same sharing caveat).
static double D_A_pack[D_SCRATCH_DIM] __attribute__((aligned(64)));
// Parameters of the "A transposed, B not transposed" (atbn) kernels.
#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called.
#define AT_MR 4 // The kernel dimension of the A transpose GEMM kernel.(AT_MR * NR).
// Forward declarations of the file-local kernels that bli_gemm_small
// dispatches to. All four share the signature of bli_gemm_small.

// Single-precision kernel; selected when A is not transposed and the
// datatype of C is BLIS_FLOAT.
static err_t bli_sgemm_small
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);
// Double-precision kernel; selected when A is not transposed and the
// datatype of C is BLIS_DOUBLE.
static err_t bli_dgemm_small
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);
// Single-precision kernel for the case where A is transposed and B is not
// ("atbn"); selected when the datatype of C is BLIS_FLOAT.
static err_t bli_sgemm_small_atbn
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);
// Double-precision kernel for the case where A is transposed and B is not
// ("atbn"); selected when the datatype of C is BLIS_DOUBLE.
static err_t bli_dgemm_small_atbn
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);
/*
 * bli_gemm_small
 *
 * Entry point of the small-matrix GEMM path. It checks that the operands can
 * be handled by the custom MR x NR kernels in this file and dispatches to the
 * kernel matching the datatype of C and the transposition of A and B.
 * Problem-size limits are not checked here; bli_sgemm_small, for example,
 * tests M, N and K against BLIS_SMALL_MATRIX_THRES and the rectangular
 * thresholds itself.
 *
 * Parameters:
 *   alpha, beta - scalar objects, forwarded to the selected kernel.
 *   a, b, c     - matrix objects; all three must have a row stride of 1.
 *   cntx, cntl  - forwarded unchanged to the selected kernel.
 *
 * Returns:
 *   - the selected kernel's return value, when a kernel is selected;
 *   - BLIS_INVALID_ROW_STRIDE if a, b or c has a row stride other than 1;
 *   - BLIS_NOT_YET_IMPLEMENTED when no kernel applies: multithreaded builds,
 *     alpha equal to zero, A transposed together with B transposed, or a
 *     datatype of C other than BLIS_FLOAT / BLIS_DOUBLE.
 */
err_t bli_gemm_small
     (
       obj_t*  alpha,
       obj_t*  a,
       obj_t*  b,
       obj_t*  beta,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl
     )
{
#ifdef BLIS_ENABLE_MULTITHREADING
    // The small-matrix kernels are never used in multithreaded builds. The
    // code below is still compiled (but unreachable) so that the static
    // kernels remain referenced.
    return BLIS_NOT_YET_IMPLEMENTED;
#endif

    // alpha == 0 is not handled by this path: report "not implemented"
    // without touching C.
    if ( bli_obj_equals( alpha, &BLIS_ZERO ) )
    {
        return BLIS_NOT_YET_IMPLEMENTED;
    }

    // The kernels only support operands whose row stride is 1.
    if ( ( bli_obj_row_stride( a ) != 1 ) ||
         ( bli_obj_row_stride( b ) != 1 ) ||
         ( bli_obj_row_stride( c ) != 1 ) )
    {
        return BLIS_INVALID_ROW_STRIDE;
    }

    // Datatype of C, taken from the low three bits of its info word.
    // NOTE(review): presumably what the bli_obj_dt() accessor returns --
    // confirm and switch to the accessor.
    num_t dt = ( c->info & ( 0x7 << 0 ) );

    if ( bli_obj_has_trans( a ) )
    {
        // Of the transposed-A cases, only "A transposed, B not transposed"
        // has dedicated kernels.
        if ( bli_obj_has_notrans( b ) )
        {
            if ( dt == BLIS_FLOAT )
            {
                return bli_sgemm_small_atbn( alpha, a, b, beta, c, cntx, cntl );
            }
            else if ( dt == BLIS_DOUBLE )
            {
                return bli_dgemm_small_atbn( alpha, a, b, beta, c, cntx, cntl );
            }
        }

        return BLIS_NOT_YET_IMPLEMENTED;
    }

    // A is not transposed; these kernels handle the transposition of B
    // themselves.
    if ( dt == BLIS_DOUBLE )
    {
        return bli_dgemm_small( alpha, a, b, beta, c, cntx, cntl );
    }

    if ( dt == BLIS_FLOAT )
    {
        return bli_sgemm_small( alpha, a, b, beta, c, cntx, cntl );
    }

    // Any other datatype is unsupported.
    return BLIS_NOT_YET_IMPLEMENTED;
}
static err_t bli_sgemm_small
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
)
{
gint_t M = bli_obj_length( c ); // number of rows of Matrix C
gint_t N = bli_obj_width( c ); // number of columns of Matrix C
gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
gint_t L = M * N;
// printf("alpha_cast = %f beta_cast = %f [ Trans = %d %d], [stride = %d %d %d] [m,n,k = %d %d %d]\n",*alpha_cast,*beta_cast, bli_obj_has_trans( a ), bli_obj_has_trans( b ), lda, ldb,ldc, M,N,K);
if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES))
|| ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
{
guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
guint_t row_idx, col_idx, k;
float *A = a->buffer; // pointer to elements of Matrix A
float *B = b->buffer; // pointer to elements of Matrix B
float *C = c->buffer; // pointer to elements of Matrix C
float *tA = A, *tB = B, *tC = C;//, *tA_pack;
float *tA_packed; // temprorary pointer to hold packed A memory pointer
guint_t row_idx_packed; //packed A memory row index
guint_t lda_packed; //lda of packed A
guint_t col_idx_start; //starting index after A matrix is packed.
dim_t tb_inc_row = 1; // row stride of matrix B
dim_t tb_inc_col = ldb; // column stride of matrix B
__m256 ymm4, ymm5, ymm6, ymm7;
__m256 ymm8, ymm9, ymm10, ymm11;
__m256 ymm12, ymm13, ymm14, ymm15;
__m256 ymm0, ymm1, ymm2, ymm3;
gint_t n_remainder; // If the N is non multiple of 3.(N%3)
gint_t m_remainder; // If the M is non multiple of 32.(M%32)
float *alpha_cast, *beta_cast; // alpha, beta multiples
alpha_cast = (alpha->buffer);
beta_cast = (beta->buffer);
gint_t required_packing_A = 1;
// when N is equal to 1 call GEMV instead of GEMM
if (N == 1)
{
bli_gemv
(
alpha,
a,
b,
beta,
c
);
return BLIS_SUCCESS;
}
//update the pointer math if matrix B needs to be transposed.
if (bli_obj_has_trans( b ))
{
tb_inc_col = 1; //switch row and column strides
tb_inc_row = ldb;
}
if ((N <= 3) || ((MR * K) > F_SCRATCH_DIM))
{
required_packing_A = 0;
}
/*
* The computation loop runs for MRxN columns of C matrix, thus
* accessing the MRxK A matrix data and KxNR B matrix data.
* The computation is organized as inner loops of dimension MRxNR.
*/
// Process MR rows of C matrix at a time.
for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR)
{
col_idx_start = 0;
tA_packed = A;
row_idx_packed = row_idx;
lda_packed = lda;
// This is part of the pack-and-compute optimization.
// During the first column iteration, we store the accessed A matrix into
// contiguous static memory. This helps to keep the A matrix in cache and
// avoids TLB misses.
if (required_packing_A)
{
col_idx = 0;
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
tA_packed = A_pack;
#ifdef BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm11 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
// This loop is processing MR x K
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
_mm256_storeu_ps(tA_packed, ymm3); // the packing of matrix A
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
_mm256_storeu_ps(tA_packed + 8, ymm3); // the packing of matrix A
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
_mm256_storeu_ps(tA_packed + 16, ymm3); // the packing of matrix A
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
ymm3 = _mm256_loadu_ps(tA + 24);
_mm256_storeu_ps(tA_packed + 24, ymm3); // the packing of matrix A
// ymm7 += ymm0 * ymm3;
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
// ymm11 += ymm1 * ymm3;
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
// ymm15 += ymm2 * ymm3;
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
tA += lda;
tA_packed += MR;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
ymm10 = _mm256_mul_ps(ymm10, ymm0);
ymm11 = _mm256_mul_ps(ymm11, ymm0);
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate col 1.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
_mm256_storeu_ps(tC + 16, ymm6);
_mm256_storeu_ps(tC + 24, ymm7);
// multiply C by beta and accumulate, col 2.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
_mm256_storeu_ps(tC + 24, ymm11);
// multiply C by beta and accumulate, col 3.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
_mm256_storeu_ps(tC + 24, ymm15);
// modify the pointer arithmetic to use the packed A matrix.
col_idx_start = NR;
tA_packed = A_pack;
row_idx_packed = 0;
lda_packed = MR;
}
// Process NR columns of C matrix at a time.
for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = tA_packed + row_idx_packed;
#ifdef BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm11 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
// This loop is processing MR x K
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
ymm3 = _mm256_loadu_ps(tA + 24);
// ymm7 += ymm0 * ymm3;
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
// ymm11 += ymm1 * ymm3;
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
// ymm15 += ymm2 * ymm3;
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
tA += lda_packed;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
ymm10 = _mm256_mul_ps(ymm10, ymm0);
ymm11 = _mm256_mul_ps(ymm11, ymm0);
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate col 1.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
_mm256_storeu_ps(tC + 16, ymm6);
_mm256_storeu_ps(tC + 24, ymm7);
// multiply C by beta and accumulate, col 2.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
_mm256_storeu_ps(tC + 24, ymm11);
// multiply C by beta and accumulate, col 3.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
_mm256_storeu_ps(tC + 24, ymm15);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm11 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
ymm3 = _mm256_loadu_ps(tA + 24);
ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);
ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
ymm10 = _mm256_mul_ps(ymm10, ymm0);
ymm11 = _mm256_mul_ps(ymm11, ymm0);
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate, col 1.
ymm2 = _mm256_loadu_ps(tC + 0);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
_mm256_storeu_ps(tC + 0, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
_mm256_storeu_ps(tC + 24, ymm11);
// multiply C by beta and accumulate, col 2.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
_mm256_storeu_ps(tC + 24, ymm15);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
ymm3 = _mm256_loadu_ps(tA + 24);
ymm15 = _mm256_fmadd_ps(ymm0, ymm3, ymm15);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC + 0);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
_mm256_storeu_ps(tC + 0, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
_mm256_storeu_ps(tC + 24, ymm15);
}
}
m_remainder = M - row_idx;
if (m_remainder >= 24)
{
m_remainder -= 24;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
ymm10 = _mm256_mul_ps(ymm10, ymm0);
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
_mm256_storeu_ps(tC + 16, ymm6);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
ymm10 = _mm256_mul_ps(ymm10, ymm0);
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC + 0);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
_mm256_storeu_ps(tC + 0, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
ymm3 = _mm256_loadu_ps(tA + 16);
ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC + 0);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
_mm256_storeu_ps(tC + 0, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
}
row_idx += 24;
}
if (m_remainder >= 16)
{
m_remainder -= 16;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
ymm8 = _mm256_fmadd_ps(ymm2, ymm3, ymm8);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
_mm256_storeu_ps(tC, ymm6);
_mm256_storeu_ps(tC + 8, ymm7);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
_mm256_storeu_ps(tC, ymm6);
_mm256_storeu_ps(tC + 8, ymm7);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm3 = _mm256_loadu_ps(tA + 8);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
}
row_idx += 16;
}
if (m_remainder >= 8)
{
m_remainder -= 8;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
_mm256_storeu_ps(tC, ymm4);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(tC, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
_mm256_storeu_ps(tC, ymm6);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
_mm256_storeu_ps(tC, ymm4);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(tC, ymm5);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_ps();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
ymm4 = _mm256_mul_ps(ymm4, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
_mm256_storeu_ps(tC, ymm4);
}
row_idx += 8;
}
// M is not a multiple of 32.
// Handle the edge case where fewer than 8 rows remain. The last column
// of A and the columns of C are staged through an 8-float temporary
// buffer (f_temp) so that 8-wide vector loads and stores can be used.
if ((m_remainder) && (lda > 7))
{
float f_temp[8];
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm5 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_ps(tA);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
tA += lda;
}
// Final k iteration: broadcast the B elements; only m_remainder
// elements of the A column are read, via f_temp.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
_mm256_storeu_ps(f_temp, ymm7);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
_mm256_storeu_ps(f_temp, ymm9);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm5 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
ymm3 = _mm256_loadu_ps(tA);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
tA += lda;
}
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
_mm256_storeu_ps(f_temp, ymm7);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm5 = _mm256_setzero_ps();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
ymm3 = _mm256_loadu_ps(tA);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
tA += lda;
}
ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm5 = _mm256_mul_ps(ymm5, ymm0);
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
m_remainder = 0;
}
if (m_remainder)
{
float result;
for (; row_idx < M; row_idx += 1)
{
for (col_idx = 0; col_idx < N; col_idx += 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
result = 0;
for (k = 0; k < K; ++k)
{
result += (*tA) * (*tB);
tA += lda;
tB += tb_inc_row;
}
result *= (*alpha_cast);
(*tC) = (*tC) * (*beta_cast) + result;
}
}
}
return BLIS_SUCCESS;
}
else
return BLIS_NONCONFORMAL_DIMENSIONS;
};
static err_t bli_dgemm_small
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
)
{
gint_t M = bli_obj_length( c ); // number of rows of Matrix C
gint_t N = bli_obj_width( c ); // number of columns of Matrix C
gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
gint_t L = M * N;
// If alpha is zero, scale by beta and return.
// printf("alpha_cast = %f beta_cast = %f [ Trans = %d %d], [stride = %d %d %d] [m,n,k = %d %d %d]\n",*alpha_cast,*beta_cast, bli_obj_has_trans( a ), bli_obj_has_trans( b ), lda, ldb,ldc, M,N,K);
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if( (L != 0) && (K != 0) && (N < BLIS_SMALL_MATRIX_THRES_ROME) && (K < BLIS_SMALL_MATRIX_THRES_ROME))
#else
if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES))
|| ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
#endif
{
guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
guint_t row_idx, col_idx, k;
double *A = a->buffer; // pointer to elements of Matrix A
double *B = b->buffer; // pointer to elements of Matrix B
double *C = c->buffer; // pointer to elements of Matrix C
double *tA = A, *tB = B, *tC = C;//, *tA_pack;
double *tA_packed; // temprorary pointer to hold packed A memory pointer
guint_t row_idx_packed; //packed A memory row index
guint_t lda_packed; //lda of packed A
guint_t col_idx_start; //starting index after A matrix is packed.
dim_t tb_inc_row = 1; // row stride of matrix B
dim_t tb_inc_col = ldb; // column stride of matrix B
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm0, ymm1, ymm2, ymm3;
gint_t n_remainder; // If the N is non multiple of 3.(N%3)
gint_t m_remainder; // If the M is non multiple of 16.(M%16)
double *alpha_cast, *beta_cast; // alpha, beta multiples
alpha_cast = (alpha->buffer);
beta_cast = (beta->buffer);
gint_t required_packing_A = 1;
// when N is equal to 1 call GEMV instead of GEMM
if (N == 1)
{
bli_gemv
(
alpha,
a,
b,
beta,
c
);
return BLIS_SUCCESS;
}
//update the pointer math if matrix B needs to be transposed.
if (bli_obj_has_trans( b ))
{
tb_inc_col = 1; //switch row and column strides
tb_inc_row = ldb;
}
if ((N <= 3) || ((D_MR * K) > D_SCRATCH_DIM))
{
required_packing_A = 0;
}
/*
* The computation loop runs for D_MRxN columns of C matrix, thus
* accessing the D_MRxK A matrix data and KxNR B matrix data.
* The computation is organized as inner loops of dimension D_MRxNR.
*/
// Process D_MR rows of C matrix at a time.
for (row_idx = 0; (row_idx + (D_MR - 1)) < M; row_idx += D_MR)
{
col_idx_start = 0;
tA_packed = A;
row_idx_packed = row_idx;
lda_packed = lda;
// This is part of the pack-and-compute optimization.
// During the first column iteration, we store the accessed A matrix into
// contiguous static memory. This helps to keep the A matrix in cache and
// avoids TLB misses.
if (required_packing_A)
{
col_idx = 0;
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
tA_packed = D_A_pack;
#ifdef BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
// This loop is processing D_MR x K
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
_mm256_storeu_pd(tA_packed, ymm3); // the packing of matrix A
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
_mm256_storeu_pd(tA_packed + 4, ymm3); // the packing of matrix A
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
_mm256_storeu_pd(tA_packed + 8, ymm3); // the packing of matrix A
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
ymm3 = _mm256_loadu_pd(tA + 12);
_mm256_storeu_pd(tA_packed + 12, ymm3); // the packing of matrix A
// ymm7 += ymm0 * ymm3;
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
// ymm11 += ymm1 * ymm3;
ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
// ymm15 += ymm2 * ymm3;
ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
tA += lda;
tA_packed += D_MR;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm11 = _mm256_mul_pd(ymm11, ymm0);
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
// multiply C by beta and accumulate col 1.
ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
_mm256_storeu_pd(tC + 8, ymm6);
_mm256_storeu_pd(tC + 12, ymm7);
// multiply C by beta and accumulate, col 2.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);
_mm256_storeu_pd(tC, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
_mm256_storeu_pd(tC + 8, ymm10);
_mm256_storeu_pd(tC + 12, ymm11);
// multiply C by beta and accumulate, col 3.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
_mm256_storeu_pd(tC, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
_mm256_storeu_pd(tC + 12, ymm15);
// modify the pointer arithmetic to use the packed A matrix.
col_idx_start = NR;
tA_packed = D_A_pack;
row_idx_packed = 0;
lda_packed = D_MR;
}
// Process NR columns of C matrix at a time.
for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = tA_packed + row_idx_packed;
#ifdef BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
// This loop is processing D_MR x K
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
ymm3 = _mm256_loadu_pd(tA + 12);
// ymm7 += ymm0 * ymm3;
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
// ymm11 += ymm1 * ymm3;
ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
// ymm15 += ymm2 * ymm3;
ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
tA += lda_packed;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm11 = _mm256_mul_pd(ymm11, ymm0);
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
// multiply C by beta and accumulate col 1.
ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
_mm256_storeu_pd(tC + 8, ymm6);
_mm256_storeu_pd(tC + 12, ymm7);
// multiply C by beta and accumulate, col 2.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);
_mm256_storeu_pd(tC, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
_mm256_storeu_pd(tC + 8, ymm10);
_mm256_storeu_pd(tC + 12, ymm11);
// multiply C by beta and accumulate, col 3.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
_mm256_storeu_pd(tC, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
_mm256_storeu_pd(tC + 12, ymm15);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
ymm3 = _mm256_loadu_pd(tA + 12);
ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11);
ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm11 = _mm256_mul_pd(ymm11, ymm0);
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
// multiply C by beta and accumulate, col 1.
ymm2 = _mm256_loadu_pd(tC + 0);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);
_mm256_storeu_pd(tC + 0, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
_mm256_storeu_pd(tC + 8, ymm10);
_mm256_storeu_pd(tC + 12, ymm11);
// multiply C by beta and accumulate, col 2.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
_mm256_storeu_pd(tC, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
_mm256_storeu_pd(tC + 12, ymm15);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
ymm3 = _mm256_loadu_pd(tA + 12);
ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_pd(tC + 0);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_pd(tC + 12);
ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
_mm256_storeu_pd(tC + 0, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
_mm256_storeu_pd(tC + 12, ymm15);
}
}
m_remainder = M - row_idx;
if (m_remainder >= 12)
{
m_remainder -= 12;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
// ymm4 += ymm0 * ymm3;
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
// ymm8 += ymm1 * ymm3;
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
// ymm12 += ymm2 * ymm3;
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
// ymm5 += ymm0 * ymm3;
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
// ymm9 += ymm1 * ymm3;
ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
// ymm13 += ymm2 * ymm3;
ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
// ymm6 += ymm0 * ymm3;
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
// ymm10 += ymm1 * ymm3;
ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
// ymm14 += ymm2 * ymm3;
ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
_mm256_storeu_pd(tC + 8, ymm6);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
_mm256_storeu_pd(tC, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
_mm256_storeu_pd(tC + 8, ymm10);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
_mm256_storeu_pd(tC, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_pd(tC + 0);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
_mm256_storeu_pd(tC + 0, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
_mm256_storeu_pd(tC + 8, ymm10);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
_mm256_storeu_pd(tC, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
ymm3 = _mm256_loadu_pd(tA + 8);
ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm12 = _mm256_mul_pd(ymm12, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_pd(tC + 0);
ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_pd(tC + 8);
ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
_mm256_storeu_pd(tC + 0, ymm12);
_mm256_storeu_pd(tC + 4, ymm13);
_mm256_storeu_pd(tC + 8, ymm14);
}
row_idx += 12;
}
if (m_remainder >= 8)
{
m_remainder -= 8;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
ymm8 = _mm256_mul_pd(ymm8, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
_mm256_storeu_pd(tC, ymm6);
_mm256_storeu_pd(tC + 4, ymm7);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
_mm256_storeu_pd(tC, ymm8);
_mm256_storeu_pd(tC + 4, ymm9);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
_mm256_storeu_pd(tC, ymm6);
_mm256_storeu_pd(tC + 4, ymm7);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm3 = _mm256_loadu_pd(tA + 4);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_pd(tC + 4);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
_mm256_storeu_pd(tC, ymm4);
_mm256_storeu_pd(tC + 4, ymm5);
}
row_idx += 8;
}
if (m_remainder >= 4)
{
//printf("HERE\n");
m_remainder -= 4;
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm6 = _mm256_mul_pd(ymm6, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
_mm256_storeu_pd(tC, ymm4);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
_mm256_storeu_pd(tC, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
_mm256_storeu_pd(tC, ymm6);
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_pd(ymm4, ymm0);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
_mm256_storeu_pd(tC, ymm4);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_pd(tC);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
_mm256_storeu_pd(tC, ymm5);
col_idx += 2;
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm4 = _mm256_setzero_pd();
for (k = 0; k < K; ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
tA += lda;
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
ymm4 = _mm256_mul_pd(ymm4, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_pd(tC);
ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
_mm256_storeu_pd(tC, ymm4);
}
row_idx += 4;
}
// M is not a multiple of 32.
// The handling of edge case where the remainder
// dimension is less than 8. The padding takes place
// to handle this case.
if ((m_remainder) && (lda > 3))
{
double f_temp[8];
for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
// clear scratch registers.
ymm5 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
//broadcasted matrix B elements are multiplied
//with matrix A columns.
ymm3 = _mm256_loadu_pd(tA);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
tA += lda;
}
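// Final k iteration: broadcast the last B values; only m_remainder elements of A are copied into a temporary buffer before the vector load.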
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
_mm256_storeu_pd(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
_mm256_storeu_pd(f_temp, ymm7);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
_mm256_storeu_pd(f_temp, ymm9);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
n_remainder = N - col_idx;
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 2)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm5 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
ymm3 = _mm256_loadu_pd(tA);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
tA += lda;
}
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
ymm5 = _mm256_mul_pd(ymm5, ymm0);
ymm7 = _mm256_mul_pd(ymm7, ymm0);
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
_mm256_storeu_pd(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
tC += ldc;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
_mm256_storeu_pd(f_temp, ymm7);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
// if the N is not multiple of 3.
// handling edge case.
if (n_remainder == 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
ymm5 = _mm256_setzero_pd();
for (k = 0; k < (K - 1); ++k)
{
// The inner loop broadcasts the B matrix data and
// multiplies it with the A matrix.
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
ymm3 = _mm256_loadu_pd(tA);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
tA += lda;
}
ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
tB += tb_inc_row;
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tA[i];
}
ymm3 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
ymm0 = _mm256_broadcast_sd(alpha_cast);
ymm1 = _mm256_broadcast_sd(beta_cast);
//multiply A*B by alpha.
ymm5 = _mm256_mul_pd(ymm5, ymm0);
for (int i = 0; i < m_remainder; i++)
{
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_pd(f_temp);
ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
_mm256_storeu_pd(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
tC[i] = f_temp[i];
}
}
m_remainder = 0;
}
if (m_remainder)
{
double result;
for (; row_idx < M; row_idx += 1)
{
for (col_idx = 0; col_idx < N; col_idx += 1)
{
//pointer math to point to proper memory
tC = C + ldc * col_idx + row_idx;
tB = B + tb_inc_col * col_idx;
tA = A + row_idx;
result = 0;
for (k = 0; k < K; ++k)
{
result += (*tA) * (*tB);
tA += lda;
tB += tb_inc_row;
}
result *= (*alpha_cast);
(*tC) = (*tC) * (*beta_cast) + result;
}
}
}
return BLIS_SUCCESS;
}
else
return BLIS_NONCONFORMAL_DIMENSIONS;
};
static err_t bli_sgemm_small_atbn
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
)
{
gint_t M = bli_obj_length( c ); // number of rows of Matrix C
gint_t N = bli_obj_width( c ); // number of columns of Matrix C
gint_t K = bli_obj_length( b ); // number of rows of Matrix B
guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
int row_idx = 0, col_idx = 0, k;
float *A = a->buffer; // pointer to matrix A elements, stored in row major format
float *B = b->buffer; // pointer to matrix B elements, stored in column major format
float *C = c->buffer; // pointer to matrix C elements, stored in column major format
float *tA = A, *tB = B, *tC = C;
__m256 ymm4, ymm5, ymm6, ymm7;
__m256 ymm8, ymm9, ymm10, ymm11;
__m256 ymm12, ymm13, ymm14, ymm15;
__m256 ymm0, ymm1, ymm2, ymm3;
float result, scratch[8];
float *alpha_cast, *beta_cast; // alpha, beta multiples
alpha_cast = (alpha->buffer);
beta_cast = (beta->buffer);
// The non-copy version of the A^T GEMM gives better performance for the small M cases.
// The threshold is controlled by BLIS_ATBN_M_THRES
if (M <= BLIS_ATBN_M_THRES)
{
for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
{
for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm5 = _mm256_setzero_ps();
ymm6 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm8 = _mm256_setzero_ps();
ymm9 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm11 = _mm256_setzero_ps();
ymm12 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
ymm14 = _mm256_setzero_ps();
ymm15 = _mm256_setzero_ps();
//The inner loop computes the 4x3 values of the matrix.
//The computation pattern is:
// ymm4 ymm5 ymm6
// ymm7 ymm8 ymm9
// ymm10 ymm11 ymm12
// ymm13 ymm14 ymm15
//The Dot operation is performed in the inner loop, 8 float elements fit
//in the YMM register hence loop count incremented by 8
for (k = 0; (k + 7) < K; k += 8)
{
ymm0 = _mm256_loadu_ps(tB + 0);
ymm1 = _mm256_loadu_ps(tB + ldb);
ymm2 = _mm256_loadu_ps(tB + 2 * ldb);
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
ymm3 = _mm256_loadu_ps(tA + lda);
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
ymm3 = _mm256_loadu_ps(tA + 2 * lda);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_ps(tA + 3 * lda);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
tA += 8;
tB += 8;
}
// if K is not a multiple of 8, padding is done before load using temproary array.
if (k < K)
{
int iter;
float data_feeder[8] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_ps(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
ymm1 = _mm256_loadu_ps(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
ymm2 = _mm256_loadu_ps(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
}
//horizontal addition and storage of the data.
//Results for 4x3 blocks of C is stored here
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
_mm256_storeu_ps(scratch, ymm4);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[0] = result + tC[0] * (*beta_cast);
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
_mm256_storeu_ps(scratch, ymm7);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[1] = result + tC[1] * (*beta_cast);
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
_mm256_storeu_ps(scratch, ymm10);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[2] = result + tC[2] * (*beta_cast);
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
_mm256_storeu_ps(scratch, ymm13);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[3] = result + tC[3] * (*beta_cast);
tC += ldc;
ymm5 = _mm256_hadd_ps(ymm5, ymm5);
ymm5 = _mm256_hadd_ps(ymm5, ymm5);
_mm256_storeu_ps(scratch, ymm5);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[0] = result + tC[0] * (*beta_cast);
ymm8 = _mm256_hadd_ps(ymm8, ymm8);
ymm8 = _mm256_hadd_ps(ymm8, ymm8);
_mm256_storeu_ps(scratch, ymm8);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[1] = result + tC[1] * (*beta_cast);
ymm11 = _mm256_hadd_ps(ymm11, ymm11);
ymm11 = _mm256_hadd_ps(ymm11, ymm11);
_mm256_storeu_ps(scratch, ymm11);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[2] = result + tC[2] * (*beta_cast);
ymm14 = _mm256_hadd_ps(ymm14, ymm14);
ymm14 = _mm256_hadd_ps(ymm14, ymm14);
_mm256_storeu_ps(scratch, ymm14);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[3] = result + tC[3] * (*beta_cast);
tC += ldc;
ymm6 = _mm256_hadd_ps(ymm6, ymm6);
ymm6 = _mm256_hadd_ps(ymm6, ymm6);
_mm256_storeu_ps(scratch, ymm6);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[0] = result + tC[0] * (*beta_cast);
ymm9 = _mm256_hadd_ps(ymm9, ymm9);
ymm9 = _mm256_hadd_ps(ymm9, ymm9);
_mm256_storeu_ps(scratch, ymm9);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[1] = result + tC[1] * (*beta_cast);
ymm12 = _mm256_hadd_ps(ymm12, ymm12);
ymm12 = _mm256_hadd_ps(ymm12, ymm12);
_mm256_storeu_ps(scratch, ymm12);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[2] = result + tC[2] * (*beta_cast);
ymm15 = _mm256_hadd_ps(ymm15, ymm15);
ymm15 = _mm256_hadd_ps(ymm15, ymm15);
_mm256_storeu_ps(scratch, ymm15);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[3] = result + tC[3] * (*beta_cast);
}
}
int processed_col = col_idx;
int processed_row = row_idx;
//The edge case handling where N is not a multiple of 3
if (processed_col < N)
{
for (col_idx = processed_col; col_idx < N; col_idx += 1)
{
for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
ymm7 = _mm256_setzero_ps();
ymm10 = _mm256_setzero_ps();
ymm13 = _mm256_setzero_ps();
//The inner loop computes the 4x1 values of the matrix.
//The computation pattern is:
// ymm4
// ymm7
// ymm10
// ymm13
for (k = 0; (k + 7) < K; k += 8)
{
ymm0 = _mm256_loadu_ps(tB + 0);
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
ymm3 = _mm256_loadu_ps(tA + lda);
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
ymm3 = _mm256_loadu_ps(tA + 2 * lda);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
ymm3 = _mm256_loadu_ps(tA + 3 * lda);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
tA += 8;
tB += 8;
}
// if K is not a multiple of 8, padding is done before load using temproary array.
if (k < K)
{
int iter;
float data_feeder[8] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_ps(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
}
//horizontal addition and storage of the data.
//Results for 4x1 blocks of C is stored here
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
_mm256_storeu_ps(scratch, ymm4);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[0] = result + tC[0] * (*beta_cast);
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
_mm256_storeu_ps(scratch, ymm7);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[1] = result + tC[1] * (*beta_cast);
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
_mm256_storeu_ps(scratch, ymm10);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[2] = result + tC[2] * (*beta_cast);
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
_mm256_storeu_ps(scratch, ymm13);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[3] = result + tC[3] * (*beta_cast);
}
}
processed_row = row_idx;
}
//The edge case handling where M is not a multiple of 4
if (processed_row < M)
{
for (row_idx = processed_row; row_idx < M; row_idx += 1)
{
for (col_idx = 0; col_idx < N; col_idx += 1)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_ps();
for (k = 0; (k + 7) < K; k += 8)
{
ymm0 = _mm256_loadu_ps(tB + 0);
ymm3 = _mm256_loadu_ps(tA);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
tA += 8;
tB += 8;
}
// if K is not a multiple of 8, padding is done before load using temproary array.
if (k < K)
{
int iter;
float data_feeder[8] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_ps(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_ps(data_feeder);
ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
}
//horizontal addition and storage of the data.
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
ymm4 = _mm256_hadd_ps(ymm4, ymm4);
_mm256_storeu_ps(scratch, ymm4);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
tC[0] = result + tC[0] * (*beta_cast);
}
}
}
return BLIS_SUCCESS;
}
else
return BLIS_NONCONFORMAL_DIMENSIONS;
}
static err_t bli_dgemm_small_atbn
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
)
{
gint_t M = bli_obj_length( c ); // number of rows of Matrix C
gint_t N = bli_obj_width( c ); // number of columns of Matrix C
gint_t K = bli_obj_length( b ); // number of rows of Matrix B
guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
guint_t row_idx = 0, col_idx = 0, k;
double *A = a->buffer; // pointer to matrix A elements, stored in row major format
double *B = b->buffer; // pointer to matrix B elements, stored in column major format
double *C = c->buffer; // pointer to matrix C elements, stored in column major format
double *tA = A, *tB = B, *tC = C;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm0, ymm1, ymm2, ymm3;
double result, scratch[8];
double *alpha_cast, *beta_cast; // alpha, beta multiples
alpha_cast = (alpha->buffer);
beta_cast = (beta->buffer);
// The non-copy version of the A^T GEMM gives better performance for the small M cases.
// The threshold is controlled by BLIS_ATBN_M_THRES
if (M <= BLIS_ATBN_M_THRES)
{
for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
{
for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
//The inner loop computes the 4x3 values of the matrix.
//The computation pattern is:
// ymm4 ymm5 ymm6
// ymm7 ymm8 ymm9
// ymm10 ymm11 ymm12
// ymm13 ymm14 ymm15
//The Dot operation is performed in the inner loop, 4 double elements fit
//in the YMM register hence loop count incremented by 4
for (k = 0; (k + 3) < K; k += 4)
{
ymm0 = _mm256_loadu_pd(tB + 0);
ymm1 = _mm256_loadu_pd(tB + ldb);
ymm2 = _mm256_loadu_pd(tB + 2 * ldb);
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
ymm3 = _mm256_loadu_pd(tA + lda);
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
ymm3 = _mm256_loadu_pd(tA + 2 * lda);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
ymm3 = _mm256_loadu_pd(tA + 3 * lda);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
tA += 4;
tB += 4;
}
// if K is not a multiple of 4, padding is done before load using a temporary array.
if (k < K)
{
int iter;
double data_feeder[4] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_pd(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
ymm1 = _mm256_loadu_pd(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
ymm2 = _mm256_loadu_pd(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
}
//horizontal addition and storage of the data.
//Results for 4x3 blocks of C is stored here
ymm4 = _mm256_hadd_pd(ymm4, ymm4);
_mm256_storeu_pd(scratch, ymm4);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[0] = result + tC[0] * (*beta_cast);
ymm7 = _mm256_hadd_pd(ymm7, ymm7);
_mm256_storeu_pd(scratch, ymm7);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[1] = result + tC[1] * (*beta_cast);
ymm10 = _mm256_hadd_pd(ymm10, ymm10);
_mm256_storeu_pd(scratch, ymm10);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[2] = result + tC[2] * (*beta_cast);
ymm13 = _mm256_hadd_pd(ymm13, ymm13);
_mm256_storeu_pd(scratch, ymm13);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[3] = result + tC[3] * (*beta_cast);
tC += ldc;
ymm5 = _mm256_hadd_pd(ymm5, ymm5);
_mm256_storeu_pd(scratch, ymm5);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[0] = result + tC[0] * (*beta_cast);
ymm8 = _mm256_hadd_pd(ymm8, ymm8);
_mm256_storeu_pd(scratch, ymm8);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[1] = result + tC[1] * (*beta_cast);
ymm11 = _mm256_hadd_pd(ymm11, ymm11);
_mm256_storeu_pd(scratch, ymm11);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[2] = result + tC[2] * (*beta_cast);
ymm14 = _mm256_hadd_pd(ymm14, ymm14);
_mm256_storeu_pd(scratch, ymm14);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[3] = result + tC[3] * (*beta_cast);
tC += ldc;
ymm6 = _mm256_hadd_pd(ymm6, ymm6);
_mm256_storeu_pd(scratch, ymm6);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[0] = result + tC[0] * (*beta_cast);
ymm9 = _mm256_hadd_pd(ymm9, ymm9);
_mm256_storeu_pd(scratch, ymm9);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[1] = result + tC[1] * (*beta_cast);
ymm12 = _mm256_hadd_pd(ymm12, ymm12);
_mm256_storeu_pd(scratch, ymm12);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[2] = result + tC[2] * (*beta_cast);
ymm15 = _mm256_hadd_pd(ymm15, ymm15);
_mm256_storeu_pd(scratch, ymm15);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[3] = result + tC[3] * (*beta_cast);
}
}
int processed_col = col_idx;
int processed_row = row_idx;
//The edge case handling where N is not a multiple of 3
if (processed_col < N)
{
for (col_idx = processed_col; col_idx < N; col_idx += 1)
{
for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
//The inner loop computes the 4x1 values of the matrix.
//The computation pattern is:
// ymm4
// ymm7
// ymm10
// ymm13
for (k = 0; (k + 3) < K; k += 4)
{
ymm0 = _mm256_loadu_pd(tB + 0);
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
ymm3 = _mm256_loadu_pd(tA + lda);
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
ymm3 = _mm256_loadu_pd(tA + 2 * lda);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
ymm3 = _mm256_loadu_pd(tA + 3 * lda);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
tA += 4;
tB += 4;
}
// If K is not a multiple of 4, the remaining (K - k) elements are copied into a zero-padded temporary array before each vector load.
if (k < K)
{
int iter;
double data_feeder[4] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_pd(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
}
//horizontal addition and storage of the data.
//Results for 4x1 blocks of C is stored here
ymm4 = _mm256_hadd_pd(ymm4, ymm4);
_mm256_storeu_pd(scratch, ymm4);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[0] = result + tC[0] * (*beta_cast);
ymm7 = _mm256_hadd_pd(ymm7, ymm7);
_mm256_storeu_pd(scratch, ymm7);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[1] = result + tC[1] * (*beta_cast);
ymm10 = _mm256_hadd_pd(ymm10, ymm10);
_mm256_storeu_pd(scratch, ymm10);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[2] = result + tC[2] * (*beta_cast);
ymm13 = _mm256_hadd_pd(ymm13, ymm13);
_mm256_storeu_pd(scratch, ymm13);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[3] = result + tC[3] * (*beta_cast);
}
}
processed_row = row_idx;
}
// The edge case handling where M is not a multiple of 4
if (processed_row < M)
{
for (row_idx = processed_row; row_idx < M; row_idx += 1)
{
for (col_idx = 0; col_idx < N; col_idx += 1)
{
tA = A + row_idx * lda;
tB = B + col_idx * ldb;
tC = C + col_idx * ldc + row_idx;
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
for (k = 0; (k + 3) < K; k += 4)
{
ymm0 = _mm256_loadu_pd(tB + 0);
ymm3 = _mm256_loadu_pd(tA);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
tA += 4;
tB += 4;
}
// If K is not a multiple of 4, the remaining (K - k) elements are copied into a zero-padded temporary array before each vector load.
if (k < K)
{
int iter;
double data_feeder[4] = { 0.0 };
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
ymm0 = _mm256_loadu_pd(data_feeder);
for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
ymm3 = _mm256_loadu_pd(data_feeder);
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
}
//horizontal addition and storage of the data.
ymm4 = _mm256_hadd_pd(ymm4, ymm4);
_mm256_storeu_pd(scratch, ymm4);
result = scratch[0] + scratch[2];
result *= (*alpha_cast);
tC[0] = result + tC[0] * (*beta_cast);
}
}
}
return BLIS_SUCCESS;
}
else
return BLIS_NONCONFORMAL_DIMENSIONS;
}
#endif