mirror of
https://github.com/amd/blis.git
synced 2026-05-11 17:50:00 +00:00
Added 16x4 AXPYF kernel for zen2 config
Details: - Added a new AXPYF kernel with fuse_factor = 4 and iter_unroll = 4. - Modified blas interface of GEMM to call GEMV whenever m=1 or n=1. Change-Id: I3f5acd37b009f53cf63f462cec79fd3e73676dbc
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -52,6 +52,8 @@ void PASTEMAC(ch,varname) \
|
||||
cntx_t* cntx \
|
||||
) \
|
||||
{ \
|
||||
\
|
||||
if(cntx == NULL) cntx = bli_gks_query_cntx(); \
|
||||
\
|
||||
const num_t dt = PASTEMAC(ch,type); \
|
||||
\
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -177,8 +177,8 @@ void bli_dgemv_unf_var2
|
||||
NULL
|
||||
);
|
||||
|
||||
/* Query the context for the kernel function pointer and fusing factor. */
|
||||
b_fuse = 5;
|
||||
/* Fusing factor. */
|
||||
b_fuse = 4;
|
||||
|
||||
for ( i = 0; i < n_iter; i += f )
|
||||
{
|
||||
@@ -189,7 +189,7 @@ void bli_dgemv_unf_var2
|
||||
y1 = y + (0 )*incy;
|
||||
|
||||
/* y = y + alpha * A1 * x1; */
|
||||
bli_daxpyf_zen_int_5
|
||||
bli_daxpyf_zen_int_16x4
|
||||
(
|
||||
conja,
|
||||
conjx,
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2019, Advanced Micro Devices, Inc.
|
||||
Copyright (C) 2019 - 21, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -149,6 +149,9 @@ void PASTEF77(ch,blasname) \
|
||||
trans_t blis_transb; \
|
||||
dim_t m0, n0, k0; \
|
||||
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \
|
||||
\
|
||||
dim_t m0_a, n0_a; \
|
||||
dim_t m0_b, n0_b; \
|
||||
\
|
||||
/* Initialize BLIS. */ \
|
||||
bli_init_auto(); \
|
||||
@@ -184,6 +187,71 @@ void PASTEF77(ch,blasname) \
|
||||
const inc_t cs_b = *ldb; \
|
||||
const inc_t rs_c = 1; \
|
||||
const inc_t cs_c = *ldc; \
|
||||
\
|
||||
if( n0 == 1 ) \
|
||||
{ \
|
||||
if(bli_is_notrans(blis_transa)) \
|
||||
{ \
|
||||
PASTEMAC(ch,gemv_unf_var2)( \
|
||||
BLIS_NO_TRANSPOSE, \
|
||||
bli_extract_conj(blis_transb), \
|
||||
m0, k0, \
|
||||
(ftype*)alpha, \
|
||||
(ftype*)a, rs_a, cs_a,\
|
||||
(ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \
|
||||
(ftype*) beta, \
|
||||
c, rs_c, \
|
||||
NULL \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
PASTEMAC(ch,gemv_unf_var1)( \
|
||||
blis_transa, \
|
||||
bli_extract_conj(blis_transb), \
|
||||
k0, m0, \
|
||||
(ftype*)alpha, \
|
||||
(ftype*)a, rs_a, cs_a, \
|
||||
(ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \
|
||||
(ftype*)beta, \
|
||||
c, rs_c, \
|
||||
NULL \
|
||||
); \
|
||||
} \
|
||||
return; \
|
||||
} \
|
||||
else if( m0 == 1 ) \
|
||||
{ \
|
||||
if(bli_is_notrans(blis_transb)) \
|
||||
{ \
|
||||
PASTEMAC(ch,gemv_unf_var1)( \
|
||||
blis_transb, \
|
||||
bli_extract_conj(blis_transa), \
|
||||
n0, k0, \
|
||||
(ftype*)alpha, \
|
||||
(ftype*)b, cs_b, rs_b, \
|
||||
(ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \
|
||||
(ftype*)beta, \
|
||||
c, cs_c, \
|
||||
NULL \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
PASTEMAC(ch,gemv_unf_var2)( \
|
||||
blis_transb, \
|
||||
bli_extract_conj(blis_transa), \
|
||||
k0, n0, \
|
||||
(ftype*)alpha, \
|
||||
(ftype*)b, cs_b, rs_b, \
|
||||
(ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \
|
||||
(ftype*)beta, \
|
||||
c, cs_c, \
|
||||
NULL \
|
||||
); \
|
||||
} \
|
||||
return; \
|
||||
} \
|
||||
\
|
||||
const num_t dt = PASTEMAC(ch,type); \
|
||||
\
|
||||
@@ -192,9 +260,6 @@ void PASTEF77(ch,blasname) \
|
||||
obj_t bo = BLIS_OBJECT_INITIALIZER; \
|
||||
obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \
|
||||
obj_t co = BLIS_OBJECT_INITIALIZER; \
|
||||
\
|
||||
dim_t m0_a, n0_a; \
|
||||
dim_t m0_b, n0_b; \
|
||||
\
|
||||
bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \
|
||||
bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -48,9 +48,16 @@ typedef union
|
||||
typedef union
|
||||
{
|
||||
__m256d v;
|
||||
__m128d xmm[2];
|
||||
double d[4] __attribute__((aligned(64)));
|
||||
} v4df_t;
|
||||
|
||||
typedef union
|
||||
{
|
||||
__m128d v;
|
||||
double d[2] __attribute__((aligned(64)));
|
||||
} v2df_t;
|
||||
|
||||
|
||||
void bli_saxpyf_zen_int_5
|
||||
(
|
||||
@@ -597,6 +604,737 @@ void bli_daxpyf_zen_int_5
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
static void bli_daxpyf_zen_int_16x2
|
||||
(
|
||||
conj_t conja,
|
||||
conj_t conjx,
|
||||
dim_t m,
|
||||
dim_t b_n,
|
||||
double* restrict alpha,
|
||||
double* restrict a, inc_t inca, inc_t lda,
|
||||
double* restrict x, inc_t incx,
|
||||
double* restrict y, inc_t incy,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
const dim_t fuse_fac = 2;
|
||||
|
||||
const dim_t n_elem_per_reg = 4;
|
||||
const dim_t n_iter_unroll = 4;
|
||||
|
||||
dim_t i;
|
||||
|
||||
double* restrict a0;
|
||||
double* restrict a1;
|
||||
|
||||
double* restrict y0;
|
||||
|
||||
v4df_t chi0v, chi1v;
|
||||
|
||||
v4df_t a00v, a01v;
|
||||
|
||||
v4df_t a10v, a11v;
|
||||
|
||||
v4df_t a20v, a21v;
|
||||
|
||||
v4df_t a30v, a31v;
|
||||
|
||||
v4df_t y0v, y1v, y2v, y3v;
|
||||
|
||||
double chi0, chi1;
|
||||
|
||||
v2df_t a40v, a41v;
|
||||
|
||||
v2df_t y4v;
|
||||
// If either dimension is zero, or if alpha is zero, return early.
|
||||
if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return;
|
||||
|
||||
// If b_n is not equal to the fusing factor, then perform the entire
|
||||
// operation as a loop over axpyv.
|
||||
if ( b_n != fuse_fac )
|
||||
{
|
||||
#ifdef BLIS_CONFIG_EPYC
|
||||
for ( i = 0; i < b_n; ++i )
|
||||
{
|
||||
double* a1 = a + (0 )*inca + (i )*lda;
|
||||
double* chi1 = x + (i )*incx;
|
||||
double* y1 = y + (0 )*incy;
|
||||
double alpha_chi1;
|
||||
|
||||
bli_dcopycjs( conjx, *chi1, alpha_chi1 );
|
||||
bli_dscals( *alpha, alpha_chi1 );
|
||||
|
||||
bli_daxpyv_zen_int10
|
||||
(
|
||||
conja,
|
||||
m,
|
||||
&alpha_chi1,
|
||||
a1, inca,
|
||||
y1, incy,
|
||||
cntx
|
||||
);
|
||||
}
|
||||
|
||||
#else
|
||||
daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx );
|
||||
|
||||
for ( i = 0; i < b_n; ++i )
|
||||
{
|
||||
double* a1 = a + (0 )*inca + (i )*lda;
|
||||
double* chi1 = x + (i )*incx;
|
||||
double* y1 = y + (0 )*incy;
|
||||
double alpha_chi1;
|
||||
|
||||
bli_dcopycjs( conjx, *chi1, alpha_chi1 );
|
||||
bli_dscals( *alpha, alpha_chi1 );
|
||||
|
||||
f
|
||||
(
|
||||
conja,
|
||||
m,
|
||||
&alpha_chi1,
|
||||
a1, inca,
|
||||
y1, incy,
|
||||
cntx
|
||||
);
|
||||
}
|
||||
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
// At this point, we know that b_n is exactly equal to the fusing factor.
|
||||
|
||||
a0 = a + 0*lda;
|
||||
a1 = a + 1*lda;
|
||||
|
||||
y0 = y;
|
||||
|
||||
chi0 = *( x + 0*incx );
|
||||
chi1 = *( x + 1*incx );
|
||||
|
||||
|
||||
// Scale each chi scalar by alpha.
|
||||
bli_dscals( *alpha, chi0 );
|
||||
bli_dscals( *alpha, chi1 );
|
||||
|
||||
// Broadcast the (alpha*chi?) scalars to all elements of vector registers.
|
||||
chi0v.v = _mm256_broadcast_sd( &chi0 );
|
||||
chi1v.v = _mm256_broadcast_sd( &chi1 );
|
||||
|
||||
// If there are vectorized iterations, perform them with vector
|
||||
// instructions.
|
||||
if ( inca == 1 && incy == 1 )
|
||||
{
|
||||
for ( i = 0; (i + 15) < m; i += 16 )
|
||||
{
|
||||
// Load the input values.
|
||||
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
|
||||
y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
|
||||
y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );
|
||||
|
||||
a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
|
||||
a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg );
|
||||
a30v.v = _mm256_loadu_pd( a0 + 3*n_elem_per_reg );
|
||||
|
||||
a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
|
||||
a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg );
|
||||
a31v.v = _mm256_loadu_pd( a1 + 3*n_elem_per_reg );
|
||||
|
||||
// perform : y += alpha * x;
|
||||
y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v );
|
||||
y3v.v = _mm256_fmadd_pd( a30v.v, chi0v.v, y3v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v );
|
||||
y3v.v = _mm256_fmadd_pd( a31v.v, chi1v.v, y3v.v );
|
||||
|
||||
// Store the output.
|
||||
_mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 3*n_elem_per_reg), y3v.v );
|
||||
|
||||
y0 += n_iter_unroll * n_elem_per_reg;
|
||||
a0 += n_iter_unroll * n_elem_per_reg;
|
||||
a1 += n_iter_unroll * n_elem_per_reg;
|
||||
}
|
||||
|
||||
for ( ; (i + 11) < m; i += 12 )
|
||||
{
|
||||
// Load the input values.
|
||||
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
|
||||
y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
|
||||
|
||||
a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
|
||||
a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg );
|
||||
|
||||
a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
|
||||
a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg );
|
||||
|
||||
// perform : y += alpha * x;
|
||||
y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v );
|
||||
|
||||
// Store the output.
|
||||
_mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v );
|
||||
|
||||
y0 += 3 * n_elem_per_reg;
|
||||
a0 += 3 * n_elem_per_reg;
|
||||
a1 += 3 * n_elem_per_reg;
|
||||
}
|
||||
for ( ; (i + 7) < m; i += 8 )
|
||||
{
|
||||
// Load the input values.
|
||||
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
|
||||
|
||||
a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
|
||||
|
||||
a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
|
||||
|
||||
// perform : y += alpha * x;
|
||||
y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
|
||||
|
||||
// Store the output.
|
||||
_mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
|
||||
|
||||
y0 += 2 * n_elem_per_reg;
|
||||
a0 += 2 * n_elem_per_reg;
|
||||
a1 += 2 * n_elem_per_reg;
|
||||
}
|
||||
|
||||
for ( ; (i + 3) < m; i += 4 )
|
||||
{
|
||||
// Load the input values.
|
||||
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
|
||||
a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
|
||||
a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
|
||||
// perform : y += alpha * x;
|
||||
y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
|
||||
|
||||
// Store the output.
|
||||
_mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
|
||||
|
||||
y0 += n_elem_per_reg;
|
||||
a0 += n_elem_per_reg;
|
||||
a1 += n_elem_per_reg;
|
||||
}
|
||||
|
||||
for ( ; (i + 1) < m; i += 2 )
|
||||
{
|
||||
// Load the input values.
|
||||
y4v.v = _mm_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
|
||||
a40v.v = _mm_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
|
||||
a41v.v = _mm_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
|
||||
// perform : y += alpha * x;
|
||||
y4v.v = _mm_fmadd_pd( a40v.v, chi0v.xmm[0], y4v.v );
|
||||
|
||||
y4v.v = _mm_fmadd_pd( a41v.v, chi1v.xmm[0], y4v.v );
|
||||
|
||||
// Store the output.
|
||||
_mm_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y4v.v );
|
||||
|
||||
y0 += 2;
|
||||
a0 += 2;
|
||||
a1 += 2;
|
||||
}
|
||||
|
||||
// If there are leftover iterations, perform them with scalar code.
|
||||
for ( ; (i + 0) < m ; ++i )
|
||||
{
|
||||
double y0c = *y0;
|
||||
|
||||
const double a0c = *a0;
|
||||
const double a1c = *a1;
|
||||
|
||||
y0c += chi0 * a0c;
|
||||
y0c += chi1 * a1c;
|
||||
|
||||
*y0 = y0c;
|
||||
|
||||
a0 += 1;
|
||||
a1 += 1;
|
||||
y0 += 1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for ( i = 0; (i + 0) < m ; ++i )
|
||||
{
|
||||
double y0c = *y0;
|
||||
|
||||
const double a0c = *a0;
|
||||
const double a1c = *a1;
|
||||
|
||||
y0c += chi0 * a0c;
|
||||
y0c += chi1 * a1c;
|
||||
|
||||
*y0 = y0c;
|
||||
|
||||
a0 += inca;
|
||||
a1 += inca;
|
||||
y0 += incy;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
void bli_daxpyf_zen_int_16x4
|
||||
(
|
||||
conj_t conja,
|
||||
conj_t conjx,
|
||||
dim_t m,
|
||||
dim_t b_n,
|
||||
double* restrict alpha,
|
||||
double* restrict a, inc_t inca, inc_t lda,
|
||||
double* restrict x, inc_t incx,
|
||||
double* restrict y, inc_t incy,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
const dim_t fuse_fac = 4;
|
||||
|
||||
const dim_t n_elem_per_reg = 4;
|
||||
const dim_t n_iter_unroll = 4;
|
||||
|
||||
dim_t i;
|
||||
|
||||
double* restrict a0;
|
||||
double* restrict a1;
|
||||
double* restrict a2;
|
||||
double* restrict a3;
|
||||
|
||||
double* restrict y0;
|
||||
|
||||
v4df_t chi0v, chi1v, chi2v, chi3v;
|
||||
|
||||
v4df_t a00v, a01v, a02v, a03v;
|
||||
|
||||
v4df_t a10v, a11v, a12v, a13v;
|
||||
|
||||
v4df_t a20v, a21v, a22v, a23v;
|
||||
|
||||
v4df_t a30v, a31v, a32v, a33v;
|
||||
|
||||
v4df_t y0v, y1v, y2v, y3v;
|
||||
|
||||
double chi0, chi1, chi2, chi3;
|
||||
|
||||
v2df_t y4v;
|
||||
|
||||
v2df_t a40v, a41v, a42v, a43v;
|
||||
|
||||
// If either dimension is zero, or if alpha is zero, return early.
|
||||
if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return;
|
||||
|
||||
// If b_n is not equal to the fusing factor, then perform the entire
|
||||
// operation as a loop over axpyv.
|
||||
if ( b_n != fuse_fac )
|
||||
{
|
||||
#ifdef BLIS_CONFIG_EPYC
|
||||
if(b_n & 2)
|
||||
{
|
||||
bli_daxpyf_zen_int_16x2( conja,
|
||||
conjx,
|
||||
m, 2,
|
||||
alpha, a, inca, lda,
|
||||
x, incx,
|
||||
y, incy,
|
||||
cntx
|
||||
);
|
||||
b_n -= 2;
|
||||
a += 2*lda;
|
||||
x += 2 * incx;
|
||||
}
|
||||
for ( i = 0; i < b_n; ++i )
|
||||
{
|
||||
double* a1 = a + (0 )*inca + (i )*lda;
|
||||
double* chi1 = x + (i )*incx;
|
||||
double* y1 = y + (0 )*incy;
|
||||
double alpha_chi1;
|
||||
|
||||
bli_dcopycjs( conjx, *chi1, alpha_chi1 );
|
||||
bli_dscals( *alpha, alpha_chi1 );
|
||||
|
||||
bli_daxpyv_zen_int10
|
||||
(
|
||||
conja,
|
||||
m,
|
||||
&alpha_chi1,
|
||||
a1, inca,
|
||||
y1, incy,
|
||||
cntx
|
||||
);
|
||||
}
|
||||
|
||||
#else
|
||||
daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx );
|
||||
|
||||
for ( i = 0; i < b_n; ++i )
|
||||
{
|
||||
double* a1 = a + (0 )*inca + (i )*lda;
|
||||
double* chi1 = x + (i )*incx;
|
||||
double* y1 = y + (0 )*incy;
|
||||
double alpha_chi1;
|
||||
|
||||
bli_dcopycjs( conjx, *chi1, alpha_chi1 );
|
||||
bli_dscals( *alpha, alpha_chi1 );
|
||||
|
||||
f
|
||||
(
|
||||
conja,
|
||||
m,
|
||||
&alpha_chi1,
|
||||
a1, inca,
|
||||
y1, incy,
|
||||
cntx
|
||||
);
|
||||
}
|
||||
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
// At this point, we know that b_n is exactly equal to the fusing factor.
|
||||
|
||||
a0 = a + 0*lda;
|
||||
a1 = a + 1*lda;
|
||||
a2 = a + 2*lda;
|
||||
a3 = a + 3*lda;
|
||||
|
||||
y0 = y;
|
||||
|
||||
chi0 = *( x + 0*incx );
|
||||
chi1 = *( x + 1*incx );
|
||||
chi2 = *( x + 2*incx );
|
||||
chi3 = *( x + 3*incx );
|
||||
|
||||
// Scale each chi scalar by alpha.
|
||||
bli_dscals( *alpha, chi0 );
|
||||
bli_dscals( *alpha, chi1 );
|
||||
bli_dscals( *alpha, chi2 );
|
||||
bli_dscals( *alpha, chi3 );
|
||||
|
||||
// Broadcast the (alpha*chi?) scalars to all elements of vector registers.
|
||||
chi0v.v = _mm256_broadcast_sd( &chi0 );
|
||||
chi1v.v = _mm256_broadcast_sd( &chi1 );
|
||||
chi2v.v = _mm256_broadcast_sd( &chi2 );
|
||||
chi3v.v = _mm256_broadcast_sd( &chi3 );
|
||||
|
||||
// If there are vectorized iterations, perform them with vector
|
||||
// instructions.
|
||||
if ( inca == 1 && incy == 1 )
|
||||
{
|
||||
for ( i = 0; (i + 15) < m; i += 16 )
|
||||
{
|
||||
// Load the input values.
|
||||
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
|
||||
y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
|
||||
y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );
|
||||
|
||||
a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
|
||||
a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg );
|
||||
a30v.v = _mm256_loadu_pd( a0 + 3*n_elem_per_reg );
|
||||
|
||||
a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
|
||||
a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg );
|
||||
a31v.v = _mm256_loadu_pd( a1 + 3*n_elem_per_reg );
|
||||
|
||||
a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
|
||||
a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg );
|
||||
a22v.v = _mm256_loadu_pd( a2 + 2*n_elem_per_reg );
|
||||
a32v.v = _mm256_loadu_pd( a2 + 3*n_elem_per_reg );
|
||||
|
||||
a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
|
||||
a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg );
|
||||
a23v.v = _mm256_loadu_pd( a3 + 2*n_elem_per_reg );
|
||||
a33v.v = _mm256_loadu_pd( a3 + 3*n_elem_per_reg );
|
||||
|
||||
// perform : y += alpha * x;
|
||||
y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v );
|
||||
y3v.v = _mm256_fmadd_pd( a30v.v, chi0v.v, y3v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v );
|
||||
y3v.v = _mm256_fmadd_pd( a31v.v, chi1v.v, y3v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a22v.v, chi2v.v, y2v.v );
|
||||
y3v.v = _mm256_fmadd_pd( a32v.v, chi2v.v, y3v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a23v.v, chi3v.v, y2v.v );
|
||||
y3v.v = _mm256_fmadd_pd( a33v.v, chi3v.v, y3v.v );
|
||||
|
||||
// Store the output.
|
||||
_mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 3*n_elem_per_reg), y3v.v );
|
||||
|
||||
y0 += n_iter_unroll * n_elem_per_reg;
|
||||
a0 += n_iter_unroll * n_elem_per_reg;
|
||||
a1 += n_iter_unroll * n_elem_per_reg;
|
||||
a2 += n_iter_unroll * n_elem_per_reg;
|
||||
a3 += n_iter_unroll * n_elem_per_reg;
|
||||
}
|
||||
|
||||
for ( ; (i + 11) < m; i += 12 )
|
||||
{
|
||||
// Load the input values.
|
||||
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
|
||||
y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
|
||||
|
||||
a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
|
||||
a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg );
|
||||
|
||||
a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
|
||||
a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg );
|
||||
|
||||
a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
|
||||
a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg );
|
||||
a22v.v = _mm256_loadu_pd( a2 + 2*n_elem_per_reg );
|
||||
|
||||
a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
|
||||
a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg );
|
||||
a23v.v = _mm256_loadu_pd( a3 + 2*n_elem_per_reg );
|
||||
|
||||
// perform : y += alpha * x;
|
||||
y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a22v.v, chi2v.v, y2v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v );
|
||||
y2v.v = _mm256_fmadd_pd( a23v.v, chi3v.v, y2v.v );
|
||||
|
||||
// Store the output.
|
||||
_mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v );
|
||||
|
||||
y0 += 3 * n_elem_per_reg;
|
||||
a0 += 3 * n_elem_per_reg;
|
||||
a1 += 3 * n_elem_per_reg;
|
||||
a2 += 3 * n_elem_per_reg;
|
||||
a3 += 3 * n_elem_per_reg;
|
||||
}
|
||||
|
||||
for ( ; (i + 7) < m; i += 8 )
|
||||
{
|
||||
// Load the input values.
|
||||
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
|
||||
|
||||
a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
|
||||
|
||||
a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
|
||||
|
||||
a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
|
||||
a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg );
|
||||
|
||||
a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
|
||||
a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg );
|
||||
|
||||
// perform : y += alpha * x;
|
||||
y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v );
|
||||
y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v );
|
||||
|
||||
// Store the output.
|
||||
_mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
|
||||
_mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
|
||||
|
||||
y0 += 2 * n_elem_per_reg;
|
||||
a0 += 2 * n_elem_per_reg;
|
||||
a1 += 2 * n_elem_per_reg;
|
||||
a2 += 2 * n_elem_per_reg;
|
||||
a3 += 2 * n_elem_per_reg;
|
||||
}
|
||||
|
||||
|
||||
for ( ; (i + 3) < m; i += 4)
|
||||
{
|
||||
// Load the input values.
|
||||
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
|
||||
a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
|
||||
a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
|
||||
a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
|
||||
|
||||
a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
|
||||
|
||||
// perform : y += alpha * x;
|
||||
y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v );
|
||||
|
||||
y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v );
|
||||
|
||||
// Store the output.
|
||||
_mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
|
||||
|
||||
y0 += n_elem_per_reg;
|
||||
a0 += n_elem_per_reg;
|
||||
a1 += n_elem_per_reg;
|
||||
a2 += n_elem_per_reg;
|
||||
a3 += n_elem_per_reg;
|
||||
}
|
||||
#if 1
|
||||
for ( ; (i + 1) < m; i += 2)
|
||||
{
|
||||
|
||||
// Load the input values.
|
||||
y4v.v = _mm_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
|
||||
a40v.v = _mm_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
|
||||
a41v.v = _mm_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
|
||||
a42v.v = _mm_loadu_pd( a2 + 0*n_elem_per_reg );
|
||||
|
||||
a43v.v = _mm_loadu_pd( a3 + 0*n_elem_per_reg );
|
||||
|
||||
// perform : y += alpha * x;
|
||||
y4v.v = _mm_fmadd_pd( a40v.v, chi0v.xmm[0], y4v.v );
|
||||
|
||||
y4v.v = _mm_fmadd_pd( a41v.v, chi1v.xmm[0], y4v.v );
|
||||
|
||||
y4v.v = _mm_fmadd_pd( a42v.v, chi2v.xmm[0], y4v.v );
|
||||
|
||||
y4v.v = _mm_fmadd_pd( a43v.v, chi3v.xmm[0], y4v.v );
|
||||
|
||||
// Store the output.
|
||||
_mm_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y4v.v );
|
||||
|
||||
y0 += 2;
|
||||
a0 += 2;
|
||||
a1 += 2;
|
||||
a2 += 2;
|
||||
a3 += 2;
|
||||
}
|
||||
#endif
|
||||
// If there are leftover iterations, perform them with scalar code.
|
||||
for ( ; (i + 0) < m ; ++i )
|
||||
{
|
||||
double y0c = *y0;
|
||||
|
||||
const double a0c = *a0;
|
||||
const double a1c = *a1;
|
||||
const double a2c = *a2;
|
||||
const double a3c = *a3;
|
||||
|
||||
y0c += chi0 * a0c;
|
||||
y0c += chi1 * a1c;
|
||||
y0c += chi2 * a2c;
|
||||
y0c += chi3 * a3c;
|
||||
|
||||
*y0 = y0c;
|
||||
|
||||
a0 += 1;
|
||||
a1 += 1;
|
||||
a2 += 1;
|
||||
a3 += 1;
|
||||
|
||||
y0 += 1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for ( i = 0; (i + 0) < m ; ++i )
|
||||
{
|
||||
double y0c = *y0;
|
||||
|
||||
const double a0c = *a0;
|
||||
const double a1c = *a1;
|
||||
const double a2c = *a2;
|
||||
const double a3c = *a3;
|
||||
|
||||
y0c += chi0 * a0c;
|
||||
y0c += chi1 * a1c;
|
||||
y0c += chi2 * a2c;
|
||||
y0c += chi3 * a3c;
|
||||
|
||||
*y0 = y0c;
|
||||
|
||||
a0 += inca;
|
||||
a1 += inca;
|
||||
a2 += inca;
|
||||
a3 += inca;
|
||||
|
||||
y0 += incy;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -95,6 +95,7 @@ SETV_KER_PROT(double, d, setv_zen_int)
|
||||
// axpyf (intrinsics)
|
||||
AXPYF_KER_PROT( float, s, axpyf_zen_int_8 )
|
||||
AXPYF_KER_PROT( double, d, axpyf_zen_int_8 )
|
||||
AXPYF_KER_PROT( double, d, axpyf_zen_int_16x4 )
|
||||
|
||||
AXPYF_KER_PROT( float, s, axpyf_zen_int_5 )
|
||||
AXPYF_KER_PROT( double, d, axpyf_zen_int_5 )
|
||||
|
||||
Reference in New Issue
Block a user