Added 16x4 AXPYF kernel for zen2 config

Details:
- Added a new AXPYF kernel with fuse_factor = 4 and iter_unroll = 4.
- Modified the BLAS interface of GEMM to call GEMV whenever m = 1 or n = 1.

Change-Id: I3f5acd37b009f53cf63f462cec79fd3e73676dbc
This commit is contained in:
Meghana Vankadari
2021-01-21 17:35:50 +05:30
parent 48f2366b6f
commit 2e7cf8d82f
5 changed files with 817 additions and 11 deletions

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -52,6 +52,8 @@ void PASTEMAC(ch,varname) \
cntx_t* cntx \
) \
{ \
\
if(cntx == NULL) cntx = bli_gks_query_cntx(); \
\
const num_t dt = PASTEMAC(ch,type); \
\

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -177,8 +177,8 @@ void bli_dgemv_unf_var2
NULL
);
/* Query the context for the kernel function pointer and fusing factor. */
b_fuse = 5;
/* Fusing factor. */
b_fuse = 4;
for ( i = 0; i < n_iter; i += f )
{
@@ -189,7 +189,7 @@ void bli_dgemv_unf_var2
y1 = y + (0 )*incy;
/* y = y + alpha * A1 * x1; */
bli_daxpyf_zen_int_5
bli_daxpyf_zen_int_16x4
(
conja,
conjx,

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2019, Advanced Micro Devices, Inc.
Copyright (C) 2019 - 21, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -149,6 +149,9 @@ void PASTEF77(ch,blasname) \
trans_t blis_transb; \
dim_t m0, n0, k0; \
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \
\
dim_t m0_a, n0_a; \
dim_t m0_b, n0_b; \
\
/* Initialize BLIS. */ \
bli_init_auto(); \
@@ -184,6 +187,71 @@ void PASTEF77(ch,blasname) \
const inc_t cs_b = *ldb; \
const inc_t rs_c = 1; \
const inc_t cs_c = *ldc; \
\
if( n0 == 1 ) \
{ \
if(bli_is_notrans(blis_transa)) \
{ \
PASTEMAC(ch,gemv_unf_var2)( \
BLIS_NO_TRANSPOSE, \
bli_extract_conj(blis_transb), \
m0, k0, \
(ftype*)alpha, \
(ftype*)a, rs_a, cs_a,\
(ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \
(ftype*) beta, \
c, rs_c, \
NULL \
); \
} \
else \
{ \
PASTEMAC(ch,gemv_unf_var1)( \
blis_transa, \
bli_extract_conj(blis_transb), \
k0, m0, \
(ftype*)alpha, \
(ftype*)a, rs_a, cs_a, \
(ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \
(ftype*)beta, \
c, rs_c, \
NULL \
); \
} \
return; \
} \
else if( m0 == 1 ) \
{ \
if(bli_is_notrans(blis_transb)) \
{ \
PASTEMAC(ch,gemv_unf_var1)( \
blis_transb, \
bli_extract_conj(blis_transa), \
n0, k0, \
(ftype*)alpha, \
(ftype*)b, cs_b, rs_b, \
(ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \
(ftype*)beta, \
c, cs_c, \
NULL \
); \
} \
else \
{ \
PASTEMAC(ch,gemv_unf_var2)( \
blis_transb, \
bli_extract_conj(blis_transa), \
k0, n0, \
(ftype*)alpha, \
(ftype*)b, cs_b, rs_b, \
(ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \
(ftype*)beta, \
c, cs_c, \
NULL \
); \
} \
return; \
} \
\
const num_t dt = PASTEMAC(ch,type); \
\
@@ -192,9 +260,6 @@ void PASTEF77(ch,blasname) \
obj_t bo = BLIS_OBJECT_INITIALIZER; \
obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \
obj_t co = BLIS_OBJECT_INITIALIZER; \
\
dim_t m0_a, n0_a; \
dim_t m0_b, n0_b; \
\
bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \
bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -48,9 +48,16 @@ typedef union
// Overlay type giving three views of the same 32 bytes of packed
// double-precision data, so a value loaded as a whole vector can also be
// addressed by half or by element.
typedef union
{
// The full 256-bit packed-double vector (four doubles).
__m256d v;
// The same storage viewed as two 128-bit packed-double halves; xmm[0]
// overlays d[0..1] and xmm[1] overlays d[2..3].
__m128d xmm[2];
// The same storage viewed as four individual doubles. The attribute
// requests 64-byte alignment for this member.
double d[4] __attribute__((aligned(64)));
} v4df_t;
// Overlay type giving two views of the same 16 bytes of packed
// double-precision data.
typedef union
{
// The full 128-bit packed-double vector (two doubles).
__m128d v;
// The same storage viewed as two individual doubles. The attribute
// requests 64-byte alignment for this member.
double d[2] __attribute__((aligned(64)));
} v2df_t;
void bli_saxpyf_zen_int_5
(
@@ -597,6 +604,737 @@ void bli_daxpyf_zen_int_5
}
}
// -----------------------------------------------------------------------------
// bli_daxpyf_zen_int_16x2
//
// Fused double-precision axpy over two columns of a:
//
//   y := y + (alpha * x[0]) * a(:,0) + (alpha * x[1]) * a(:,1)
//
// Parameters:
//   conja, conjx  Conjugation flags. They are only forwarded to the axpyv
//                 fallback; the fused path below does not reference them.
//   m             Number of rows, i.e. the number of elements of y (and of
//                 each column of a) that are updated.
//   b_n           Number of columns of a / elements of x to apply. The fused
//                 path runs only when b_n == 2; any other value is handled
//                 by one axpyv call per column.
//   alpha         Pointer to the scalar applied to every element of x.
//   a, inca, lda  Column j starts at a + j*lda; consecutive elements within
//                 a column are inca apart.
//   x, incx       Vector of b_n scalars, incx apart.
//   y, incy       Vector that is updated in place, elements incy apart.
//   cntx          Forwarded to the axpyv kernel; when BLIS_CONFIG_EPYC is
//                 not defined it is also used to look that kernel up.
//
// When inca == 1 and incy == 1 the rows are processed 16, 12, 8, 4 and 2 at
// a time with AVX/FMA instructions, followed by a scalar loop for a final
// odd row. For any other strides a plain scalar loop is used.
static void bli_daxpyf_zen_int_16x2
     (
       conj_t           conja,
       conj_t           conjx,
       dim_t            m,
       dim_t            b_n,
       double* restrict alpha,
       double* restrict a, inc_t inca, inc_t lda,
       double* restrict x, inc_t incx,
       double* restrict y, inc_t incy,
       cntx_t* restrict cntx
     )
{
    const dim_t fuse_fac       = 2;
    const dim_t n_elem_per_reg = 4;
    const dim_t n_iter_unroll  = 4;

    dim_t i;

    double* restrict a0;
    double* restrict a1;
    double* restrict y0;

    v4df_t chi0v, chi1v;

    v4df_t a00v, a01v;
    v4df_t a10v, a11v;
    v4df_t a20v, a21v;
    v4df_t a30v, a31v;

    v4df_t y0v, y1v, y2v, y3v;

    double chi0, chi1;

    // 128-bit registers used by the two-row cleanup loop.
    v2df_t a40v, a41v;
    v2df_t y4v;

    // If either dimension is zero, or if alpha is zero, return early.
    if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return;

    // If b_n is not equal to the fusing factor, then perform the entire
    // operation as a loop over axpyv, one call per column.
    if ( b_n != fuse_fac )
    {
#ifdef BLIS_CONFIG_EPYC
        for ( i = 0; i < b_n; ++i )
        {
            // Names are kept distinct from the fused-path locals a1/chi1.
            double* a_col  = a + i*lda;
            double* x_elem = x + i*incx;
            double  alpha_chi;

            // alpha_chi = alpha * conjx( x[i] )
            bli_dcopycjs( conjx, *x_elem, alpha_chi );
            bli_dscals( *alpha, alpha_chi );

            bli_daxpyv_zen_int10
            (
              conja,
              m,
              &alpha_chi,
              a_col, inca,
              y,     incy,
              cntx
            );
        }
#else
        // Query the context for the axpyv kernel function pointer.
        daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx );

        for ( i = 0; i < b_n; ++i )
        {
            // Names are kept distinct from the fused-path locals a1/chi1.
            double* a_col  = a + i*lda;
            double* x_elem = x + i*incx;
            double  alpha_chi;

            // alpha_chi = alpha * conjx( x[i] )
            bli_dcopycjs( conjx, *x_elem, alpha_chi );
            bli_dscals( *alpha, alpha_chi );

            f
            (
              conja,
              m,
              &alpha_chi,
              a_col, inca,
              y,     incy,
              cntx
            );
        }
#endif
        return;
    }

    // At this point, we know that b_n is exactly equal to the fusing factor.

    a0 = a + 0*lda;
    a1 = a + 1*lda;
    y0 = y;

    chi0 = *( x + 0*incx );
    chi1 = *( x + 1*incx );

    // Scale each chi scalar by alpha.
    bli_dscals( *alpha, chi0 );
    bli_dscals( *alpha, chi1 );

    // Broadcast the (alpha*chi?) scalars to all elements of vector registers.
    chi0v.v = _mm256_broadcast_sd( &chi0 );
    chi1v.v = _mm256_broadcast_sd( &chi1 );

    // If there are vectorized iterations, perform them with vector
    // instructions.
    if ( inca == 1 && incy == 1 )
    {
        // 16 rows per iteration (four 256-bit registers of y).
        for ( i = 0; (i + 15) < m; i += 16 )
        {
            // Load the input values.
            y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
            y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
            y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
            y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );

            a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
            a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
            a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg );
            a30v.v = _mm256_loadu_pd( a0 + 3*n_elem_per_reg );

            a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
            a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
            a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg );
            a31v.v = _mm256_loadu_pd( a1 + 3*n_elem_per_reg );

            // perform : y += (alpha*chi0) * a0 + (alpha*chi1) * a1;
            y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v );
            y3v.v = _mm256_fmadd_pd( a30v.v, chi0v.v, y3v.v );

            y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v );
            y3v.v = _mm256_fmadd_pd( a31v.v, chi1v.v, y3v.v );

            // Store the output.
            _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
            _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
            _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v );
            _mm256_storeu_pd( (double *)(y0 + 3*n_elem_per_reg), y3v.v );

            y0 += n_iter_unroll * n_elem_per_reg;
            a0 += n_iter_unroll * n_elem_per_reg;
            a1 += n_iter_unroll * n_elem_per_reg;
        }

        // 12 rows per iteration.
        for ( ; (i + 11) < m; i += 12 )
        {
            // Load the input values.
            y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
            y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
            y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );

            a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
            a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
            a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg );

            a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
            a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
            a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg );

            // perform : y += (alpha*chi0) * a0 + (alpha*chi1) * a1;
            y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v );

            y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v );

            // Store the output.
            _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
            _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
            _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v );

            y0 += 3 * n_elem_per_reg;
            a0 += 3 * n_elem_per_reg;
            a1 += 3 * n_elem_per_reg;
        }

        // 8 rows per iteration.
        for ( ; (i + 7) < m; i += 8 )
        {
            // Load the input values.
            y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
            y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );

            a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
            a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );

            a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
            a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );

            // perform : y += (alpha*chi0) * a0 + (alpha*chi1) * a1;
            y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );

            y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );

            // Store the output.
            _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
            _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );

            y0 += 2 * n_elem_per_reg;
            a0 += 2 * n_elem_per_reg;
            a1 += 2 * n_elem_per_reg;
        }

        // 4 rows per iteration.
        for ( ; (i + 3) < m; i += 4 )
        {
            // Load the input values.
            y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );

            a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
            a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );

            // perform : y += (alpha*chi0) * a0 + (alpha*chi1) * a1;
            y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
            y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );

            // Store the output.
            _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );

            y0 += n_elem_per_reg;
            a0 += n_elem_per_reg;
            a1 += n_elem_per_reg;
        }

        // 2 rows per iteration, using 128-bit registers. Every lane of
        // chi?v holds the same broadcast value, so its first 128-bit half
        // serves as the two-element multiplier.
        for ( ; (i + 1) < m; i += 2 )
        {
            // Load the input values.
            y4v.v = _mm_loadu_pd( y0 + 0*n_elem_per_reg );

            a40v.v = _mm_loadu_pd( a0 + 0*n_elem_per_reg );
            a41v.v = _mm_loadu_pd( a1 + 0*n_elem_per_reg );

            // perform : y += (alpha*chi0) * a0 + (alpha*chi1) * a1;
            y4v.v = _mm_fmadd_pd( a40v.v, chi0v.xmm[0], y4v.v );
            y4v.v = _mm_fmadd_pd( a41v.v, chi1v.xmm[0], y4v.v );

            // Store the output.
            _mm_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y4v.v );

            y0 += 2;
            a0 += 2;
            a1 += 2;
        }

        // If there are leftover iterations, perform them with scalar code.
        for ( ; (i + 0) < m ; ++i )
        {
            double       y0c = *y0;

            const double a0c = *a0;
            const double a1c = *a1;

            y0c += chi0 * a0c;
            y0c += chi1 * a1c;

            *y0 = y0c;

            a0 += 1;
            a1 += 1;
            y0 += 1;
        }
    }
    else
    {
        // General strides: one row at a time with scalar code.
        for ( i = 0; (i + 0) < m ; ++i )
        {
            double       y0c = *y0;

            const double a0c = *a0;
            const double a1c = *a1;

            y0c += chi0 * a0c;
            y0c += chi1 * a1c;

            *y0 = y0c;

            a0 += inca;
            a1 += inca;
            y0 += incy;
        }
    }
}
// -----------------------------------------------------------------------------
// bli_daxpyf_zen_int_16x4
//
// Fused double-precision axpy over four columns of a:
//
//   y := y + (alpha * x[0]) * a(:,0) + (alpha * x[1]) * a(:,1)
//          + (alpha * x[2]) * a(:,2) + (alpha * x[3]) * a(:,3)
//
// Parameters:
//   conja, conjx  Conjugation flags. They are only forwarded to the fallback
//                 kernels; the fused path below does not reference them.
//   m             Number of rows, i.e. the number of elements of y (and of
//                 each column of a) that are updated.
//   b_n           Number of columns of a / elements of x to apply. The fused
//                 path runs only when b_n == 4; see below for other values.
//   alpha         Pointer to the scalar applied to every element of x.
//   a, inca, lda  Column j starts at a + j*lda; consecutive elements within
//                 a column are inca apart.
//   x, incx       Vector of b_n scalars, incx apart.
//   y, incy       Vector that is updated in place, elements incy apart.
//   cntx          Forwarded to the fallback kernels; when BLIS_CONFIG_EPYC
//                 is not defined it is also used to look up the axpyv
//                 kernel.
//
// Handling of b_n != 4:
//   - With BLIS_CONFIG_EPYC defined: if bit 1 of b_n is set, the first two
//     columns are applied with bli_daxpyf_zen_int_16x2; every remaining
//     column is applied with one bli_daxpyv_zen_int10 call.
//   - Otherwise: every column is applied with one call to the axpyv kernel
//     queried from cntx.
//
// When inca == 1 and incy == 1 the rows are processed 16, 12, 8, 4 and 2 at
// a time with AVX/FMA instructions, followed by a scalar loop for a final
// odd row. For any other strides a plain scalar loop is used.
void bli_daxpyf_zen_int_16x4
     (
       conj_t           conja,
       conj_t           conjx,
       dim_t            m,
       dim_t            b_n,
       double* restrict alpha,
       double* restrict a, inc_t inca, inc_t lda,
       double* restrict x, inc_t incx,
       double* restrict y, inc_t incy,
       cntx_t* restrict cntx
     )
{
    const dim_t fuse_fac       = 4;
    const dim_t n_elem_per_reg = 4;
    const dim_t n_iter_unroll  = 4;

    dim_t i;

    double* restrict a0;
    double* restrict a1;
    double* restrict a2;
    double* restrict a3;
    double* restrict y0;

    v4df_t chi0v, chi1v, chi2v, chi3v;

    v4df_t a00v, a01v, a02v, a03v;
    v4df_t a10v, a11v, a12v, a13v;
    v4df_t a20v, a21v, a22v, a23v;
    v4df_t a30v, a31v, a32v, a33v;

    v4df_t y0v, y1v, y2v, y3v;

    double chi0, chi1, chi2, chi3;

    // 128-bit registers used by the two-row cleanup loop.
    v2df_t y4v;
    v2df_t a40v, a41v, a42v, a43v;

    // If either dimension is zero, or if alpha is zero, return early.
    if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return;

    // If b_n is not equal to the fusing factor, fall back to narrower
    // kernels instead of the four-column fused path.
    if ( b_n != fuse_fac )
    {
#ifdef BLIS_CONFIG_EPYC
        // Peel off a pair of columns with the two-column fused kernel when
        // bit 1 of b_n is set, then advance past the columns just consumed.
        if ( b_n & 2 )
        {
            bli_daxpyf_zen_int_16x2
            (
              conja,
              conjx,
              m, 2,
              alpha,
              a, inca, lda,
              x, incx,
              y, incy,
              cntx
            );

            b_n -= 2;
            a   += 2*lda;
            x   += 2*incx;
        }

        // Apply each remaining column with a single axpyv call.
        for ( i = 0; i < b_n; ++i )
        {
            // Names are kept distinct from the fused-path locals a1/chi1.
            double* a_col  = a + i*lda;
            double* x_elem = x + i*incx;
            double  alpha_chi;

            // alpha_chi = alpha * conjx( x[i] )
            bli_dcopycjs( conjx, *x_elem, alpha_chi );
            bli_dscals( *alpha, alpha_chi );

            bli_daxpyv_zen_int10
            (
              conja,
              m,
              &alpha_chi,
              a_col, inca,
              y,     incy,
              cntx
            );
        }
#else
        // Query the context for the axpyv kernel function pointer.
        daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx );

        for ( i = 0; i < b_n; ++i )
        {
            // Names are kept distinct from the fused-path locals a1/chi1.
            double* a_col  = a + i*lda;
            double* x_elem = x + i*incx;
            double  alpha_chi;

            // alpha_chi = alpha * conjx( x[i] )
            bli_dcopycjs( conjx, *x_elem, alpha_chi );
            bli_dscals( *alpha, alpha_chi );

            f
            (
              conja,
              m,
              &alpha_chi,
              a_col, inca,
              y,     incy,
              cntx
            );
        }
#endif
        return;
    }

    // At this point, we know that b_n is exactly equal to the fusing factor.

    a0 = a + 0*lda;
    a1 = a + 1*lda;
    a2 = a + 2*lda;
    a3 = a + 3*lda;
    y0 = y;

    chi0 = *( x + 0*incx );
    chi1 = *( x + 1*incx );
    chi2 = *( x + 2*incx );
    chi3 = *( x + 3*incx );

    // Scale each chi scalar by alpha.
    bli_dscals( *alpha, chi0 );
    bli_dscals( *alpha, chi1 );
    bli_dscals( *alpha, chi2 );
    bli_dscals( *alpha, chi3 );

    // Broadcast the (alpha*chi?) scalars to all elements of vector registers.
    chi0v.v = _mm256_broadcast_sd( &chi0 );
    chi1v.v = _mm256_broadcast_sd( &chi1 );
    chi2v.v = _mm256_broadcast_sd( &chi2 );
    chi3v.v = _mm256_broadcast_sd( &chi3 );

    // If there are vectorized iterations, perform them with vector
    // instructions.
    if ( inca == 1 && incy == 1 )
    {
        // 16 rows per iteration (four 256-bit registers of y).
        for ( i = 0; (i + 15) < m; i += 16 )
        {
            // Load the input values.
            y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
            y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
            y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
            y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );

            a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
            a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
            a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg );
            a30v.v = _mm256_loadu_pd( a0 + 3*n_elem_per_reg );

            a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
            a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
            a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg );
            a31v.v = _mm256_loadu_pd( a1 + 3*n_elem_per_reg );

            a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
            a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg );
            a22v.v = _mm256_loadu_pd( a2 + 2*n_elem_per_reg );
            a32v.v = _mm256_loadu_pd( a2 + 3*n_elem_per_reg );

            a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
            a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg );
            a23v.v = _mm256_loadu_pd( a3 + 2*n_elem_per_reg );
            a33v.v = _mm256_loadu_pd( a3 + 3*n_elem_per_reg );

            // perform : y += sum over j of (alpha*chi_j) * a_j;
            y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v );
            y3v.v = _mm256_fmadd_pd( a30v.v, chi0v.v, y3v.v );

            y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v );
            y3v.v = _mm256_fmadd_pd( a31v.v, chi1v.v, y3v.v );

            y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a22v.v, chi2v.v, y2v.v );
            y3v.v = _mm256_fmadd_pd( a32v.v, chi2v.v, y3v.v );

            y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a23v.v, chi3v.v, y2v.v );
            y3v.v = _mm256_fmadd_pd( a33v.v, chi3v.v, y3v.v );

            // Store the output.
            _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
            _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
            _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v );
            _mm256_storeu_pd( (double *)(y0 + 3*n_elem_per_reg), y3v.v );

            y0 += n_iter_unroll * n_elem_per_reg;
            a0 += n_iter_unroll * n_elem_per_reg;
            a1 += n_iter_unroll * n_elem_per_reg;
            a2 += n_iter_unroll * n_elem_per_reg;
            a3 += n_iter_unroll * n_elem_per_reg;
        }

        // 12 rows per iteration.
        for ( ; (i + 11) < m; i += 12 )
        {
            // Load the input values.
            y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
            y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
            y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );

            a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
            a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
            a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg );

            a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
            a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
            a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg );

            a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
            a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg );
            a22v.v = _mm256_loadu_pd( a2 + 2*n_elem_per_reg );

            a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
            a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg );
            a23v.v = _mm256_loadu_pd( a3 + 2*n_elem_per_reg );

            // perform : y += sum over j of (alpha*chi_j) * a_j;
            y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v );

            y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v );

            y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a22v.v, chi2v.v, y2v.v );

            y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v );
            y2v.v = _mm256_fmadd_pd( a23v.v, chi3v.v, y2v.v );

            // Store the output.
            _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
            _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
            _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v );

            y0 += 3 * n_elem_per_reg;
            a0 += 3 * n_elem_per_reg;
            a1 += 3 * n_elem_per_reg;
            a2 += 3 * n_elem_per_reg;
            a3 += 3 * n_elem_per_reg;
        }

        // 8 rows per iteration.
        for ( ; (i + 7) < m; i += 8 )
        {
            // Load the input values.
            y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
            y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );

            a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
            a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );

            a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
            a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );

            a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
            a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg );

            a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
            a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg );

            // perform : y += sum over j of (alpha*chi_j) * a_j;
            y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );

            y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );

            y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v );

            y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v );
            y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v );

            // Store the output.
            _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
            _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );

            y0 += 2 * n_elem_per_reg;
            a0 += 2 * n_elem_per_reg;
            a1 += 2 * n_elem_per_reg;
            a2 += 2 * n_elem_per_reg;
            a3 += 2 * n_elem_per_reg;
        }

        // 4 rows per iteration.
        for ( ; (i + 3) < m; i += 4 )
        {
            // Load the input values.
            y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );

            a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
            a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
            a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
            a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );

            // perform : y += sum over j of (alpha*chi_j) * a_j;
            y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
            y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
            y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v );
            y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v );

            // Store the output.
            _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );

            y0 += n_elem_per_reg;
            a0 += n_elem_per_reg;
            a1 += n_elem_per_reg;
            a2 += n_elem_per_reg;
            a3 += n_elem_per_reg;
        }

        // 2 rows per iteration, using 128-bit registers. Every lane of
        // chi?v holds the same broadcast value, so its first 128-bit half
        // serves as the two-element multiplier.
        for ( ; (i + 1) < m; i += 2 )
        {
            // Load the input values.
            y4v.v = _mm_loadu_pd( y0 + 0*n_elem_per_reg );

            a40v.v = _mm_loadu_pd( a0 + 0*n_elem_per_reg );
            a41v.v = _mm_loadu_pd( a1 + 0*n_elem_per_reg );
            a42v.v = _mm_loadu_pd( a2 + 0*n_elem_per_reg );
            a43v.v = _mm_loadu_pd( a3 + 0*n_elem_per_reg );

            // perform : y += sum over j of (alpha*chi_j) * a_j;
            y4v.v = _mm_fmadd_pd( a40v.v, chi0v.xmm[0], y4v.v );
            y4v.v = _mm_fmadd_pd( a41v.v, chi1v.xmm[0], y4v.v );
            y4v.v = _mm_fmadd_pd( a42v.v, chi2v.xmm[0], y4v.v );
            y4v.v = _mm_fmadd_pd( a43v.v, chi3v.xmm[0], y4v.v );

            // Store the output.
            _mm_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y4v.v );

            y0 += 2;
            a0 += 2;
            a1 += 2;
            a2 += 2;
            a3 += 2;
        }

        // If there are leftover iterations, perform them with scalar code.
        for ( ; (i + 0) < m ; ++i )
        {
            double       y0c = *y0;

            const double a0c = *a0;
            const double a1c = *a1;
            const double a2c = *a2;
            const double a3c = *a3;

            y0c += chi0 * a0c;
            y0c += chi1 * a1c;
            y0c += chi2 * a2c;
            y0c += chi3 * a3c;

            *y0 = y0c;

            a0 += 1;
            a1 += 1;
            a2 += 1;
            a3 += 1;
            y0 += 1;
        }
    }
    else
    {
        // General strides: one row at a time with scalar code.
        for ( i = 0; (i + 0) < m ; ++i )
        {
            double       y0c = *y0;

            const double a0c = *a0;
            const double a1c = *a1;
            const double a2c = *a2;
            const double a3c = *a3;

            y0c += chi0 * a0c;
            y0c += chi1 * a1c;
            y0c += chi2 * a2c;
            y0c += chi3 * a3c;

            *y0 = y0c;

            a0 += inca;
            a1 += inca;
            a2 += inca;
            a3 += inca;
            y0 += incy;
        }
    }
}
// -----------------------------------------------------------------------------

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -95,6 +95,7 @@ SETV_KER_PROT(double, d, setv_zen_int)
// axpyf (intrinsics)
AXPYF_KER_PROT( float, s, axpyf_zen_int_8 )
AXPYF_KER_PROT( double, d, axpyf_zen_int_8 )
AXPYF_KER_PROT( double, d, axpyf_zen_int_16x4 )
AXPYF_KER_PROT( float, s, axpyf_zen_int_5 )
AXPYF_KER_PROT( double, d, axpyf_zen_int_5 )