mirror of
https://github.com/amd/blis.git
synced 2026-03-25 11:47:21 +00:00
Merge pull request #516 from nicholaiTukanov/p10-sandbox-rework
P10 sandbox rework
This commit is contained in:
@@ -5,3 +5,4 @@ other
|
||||
temp
|
||||
tmp
|
||||
test
|
||||
p10_testsuite
|
||||
@@ -1,24 +1,20 @@
|
||||
### Low Precision POWER10 Kernels
|
||||
|
||||
This is a special BLIS Sandbox that allows users to call low precision POWER10 `gemm` kernels.
|
||||
This is a special BLIS Sandbox that allows users to call POWER10 reduced precision/integer `GEMM` kernels.
|
||||
|
||||
Supported kernels: `IEEE float16 (bli_shgemm), bfloat16 (bli_sbgemm), int16 (bli_i16gemm), int8 (bli_i8gemm), int4 (bli_i4gemm)`.
|
||||
|
||||
#### Introduction
|
||||
|
||||
This document describes how the low precision POWER10 `gemm` kernels are implemented. The document will also demonstrate how to call the `gemm` kernels.
|
||||
This document describes how the low precision POWER10 `gemm` kernels are implemented and explains how to call the POWER10 `GEMM` kernels.
|
||||
|
||||
**Important: This sandbox does not have the full functionality of BLIS. This sandbox can only perform single threaded, no transpose, GEMM. At this time, full functioning POWER10 hardware has not be released. Once hardware has been released, the kernels will be further optimized in areas such as prefetching and cache blocksizes.**
|
||||
**Important: These kernels does not have the full functionality of BLIS. The kernels can only perform single threaded, no transpose, GEMM.**
|
||||
|
||||
#### Implementation
|
||||
|
||||
The kernels are implemented in `generic_gemm.c`. They are instantiated with macro templates. The main template is called `GENERIC_GEMM`. This template is used to create the 5-loop `gemm` function.
|
||||
The kernels are implemented in `gemm.c`. They are instantiated with macro templates. The main template is called `GENERIC_GEMM`. This template is used to create the 5-loop `gemm` function.
|
||||
|
||||
The API points are created in `gemm_api.c`. In this file, the API points are wrappers for the functions that are created by the templates in `generic_gemm.c`.
|
||||
|
||||
#### Kernels
|
||||
|
||||
The following low precision datatypes have POWER10 `gemm` kernels: `IEEE float16, bfloat16, int16, int8, int4`.
|
||||
|
||||
#### Low Precision Types
|
||||
#### Reduced precision/integer Types
|
||||
|
||||
| BLIS type | BLIS char | Type definition | Used to represent... |
|
||||
|:-----------|:----------|:---------------------------------------|:-------------------------------------|
|
||||
@@ -28,9 +24,9 @@ The following low precision datatypes have POWER10 `gemm` kernels: `IEEE float16
|
||||
| `int8` | `i8` | `int8_t` | 8 bit integers |
|
||||
| `int4` | `i4` | `typedef union{ uint8_t v; struct { uint8_t nib1:4; uint8_t nib2:4; } bits; }` | 4 bit integers |
|
||||
|
||||
#### Low Precision API
|
||||
#### Reduced Precision/Integer API
|
||||
|
||||
The API that is used for the low precision POWER10 `gemm` kernels is similar to the existing [BLIS basic typed API](https://github.com/flame/blis/blob/master/docs/BLISTypedAPI.md). The main difference between the two is that in the existing BLIS typed API, there is only one type for the input and output matrices. However in the low precision API, there is a input and output type.
|
||||
The API that is used for the reduced precision/integer POWER10 `GEMM` kernels is similar to the existing [BLIS basic typed API](https://github.com/flame/blis/blob/master/docs/BLISTypedAPI.md). The main difference is the POWER10 kernels expect two types: `ctype_in` and `ctype_out`.
|
||||
|
||||
Thus the new `gemm` call looks like the following:
|
||||
|
||||
@@ -50,10 +46,7 @@ void bli_??gemm
|
||||
);
|
||||
```
|
||||
|
||||
The first `?` is for the output type. The second `?` is for the input type.
|
||||
|
||||
At this time for IEEE float16 and bfloat16, the only output type is single precision float. For int16, int8, and int4, the only output type is 32 bit int.
|
||||
|
||||
`??` is meant to replaced with the kernel prefix.
|
||||
|
||||
#### How To Build The Sandbox
|
||||
|
||||
@@ -64,6 +57,9 @@ Add the following flags when running the configure script to build BLIS correctl
|
||||
Ensure that you have GCC 10.2 or greater.
|
||||
|
||||
|
||||
#### P10 Testsuite
|
||||
|
||||
In `p10_testsuite`, there are performance gathering and correctness checking programs for the POWER10 reduced precision/integer `GEMM` kernels. By default, the performance gathering and correctness checking is done over square matrices ranging from 80 to 4000 in increments of 80. Performance is measured in GFLOPs, and correctness is measured using the BLIS method (detailed in `blis/testsuite/test_gemm.c`).
|
||||
|
||||
#### References
|
||||
|
||||
|
||||
@@ -36,14 +36,12 @@
|
||||
#define BLIS_SANDBOX_H
|
||||
|
||||
#include "blis.h"
|
||||
#include "gemm_api.h"
|
||||
#include "gemm_prototypes.h"
|
||||
|
||||
// NOTE: This header is the only header required to be present in the sandbox
|
||||
// implementation directory.
|
||||
|
||||
// This header is used to create the typedefs needed for low precision
|
||||
|
||||
// int4 type
|
||||
// int4
|
||||
typedef union
|
||||
{
|
||||
uint8_t v;
|
||||
@@ -54,7 +52,7 @@ typedef union
|
||||
} bits;
|
||||
} nibbles;
|
||||
|
||||
// bfloat16
|
||||
// brain float16
|
||||
typedef union
|
||||
{
|
||||
uint16_t v;
|
||||
@@ -80,36 +78,25 @@ typedef union
|
||||
|
||||
#define P10_PG_SIZE 4096
|
||||
|
||||
// microkernel prototypes
|
||||
GEMM_UKR_PROT2( bfloat16, float, sb, gemm_power10_mma_8x16 )
|
||||
GEMM_UKR_PROT2( float16, float, sh, gemm_power10_mma_8x16 )
|
||||
GEMM_UKR_PROT2( int16_t, int32_t, i16, gemm_power10_mma_8x16 )
|
||||
GEMM_UKR_PROT2( int8_t, int32_t, i8, gemm_power10_mma_8x16 )
|
||||
GEMM_UKR_PROT2( nibbles, int32_t, i4, gemm_power10_mma_8x16 )
|
||||
|
||||
/* Creates a function that initializes a matrix of type ctype with random vals */
|
||||
#define RandomMatrixMacro(ch, ctype, rand_func) \
|
||||
RM_PROT(ch, ctype) \
|
||||
{ \
|
||||
for ( int i=0; i<m; i++ ) \
|
||||
for ( int j=0; j<n; j++ ) \
|
||||
*(ap + j*cs_a + i*rs_a) = \
|
||||
(ctype) rand_func(); \
|
||||
}
|
||||
|
||||
/* Creates a function that initializes a matrix of type ctype with random vals */
|
||||
#define RandomMatrixBounded(ch, ctype, rand_func) \
|
||||
RM_B_PROT(ch, ctype) \
|
||||
{ \
|
||||
for ( int i=0; i<m; i++ ) \
|
||||
for ( int j=0; j<n; j++ ) \
|
||||
*(ap + j*cs_a + i*rs_a) = \
|
||||
(ctype) rand_func() % (upper - lower + 1) + lower; \
|
||||
}
|
||||
|
||||
// gemm kernel prototypes
|
||||
GEMM_FUNC_PROT( float16, float, sh);
|
||||
GEMM_FUNC_PROT( bfloat16, float, sb);
|
||||
GEMM_FUNC_PROT( int16_t, int32_t, i16);
|
||||
GEMM_FUNC_PROT( int8_t, int32_t, i8);
|
||||
GEMM_FUNC_PROT( nibbles, int32_t, i4);
|
||||
|
||||
// pack kernel prototypes
|
||||
PACK_MACRO_PROTO(sb, bfloat16)
|
||||
PACK_MACRO_PROTO(sh, float16)
|
||||
PACK_MACRO_PROTO(i16, int16_t)
|
||||
PACK_MACRO_PROTO(i8, int8_t)
|
||||
PACK_MACRO_PROTO(i4, nibbles)
|
||||
|
||||
#endif
|
||||
|
||||
128
sandbox/power10/gemm.c
Normal file
128
sandbox/power10/gemm.c
Normal file
@@ -0,0 +1,128 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "gemm_template.h"
|
||||
#include "bli_sandbox.h"
|
||||
|
||||
|
||||
GENERIC_GEMM(
|
||||
sb, // kernel name prefix
|
||||
bfloat16, // input type
|
||||
float, // output type
|
||||
(pb/2 + pb%2), // innermost loop iterations
|
||||
sb_pack_a,
|
||||
sb_pack_b, // pack kernel for B
|
||||
bli_sbgemm_power10_mma_8x16, // microkernel function name
|
||||
2, // K_MMA
|
||||
8, // MR
|
||||
16, // NR
|
||||
384, // MC
|
||||
3328, // KC
|
||||
4096, // NC
|
||||
0, // A_ALIGN
|
||||
0 // B_ALIGN
|
||||
);
|
||||
|
||||
GENERIC_GEMM(
|
||||
sh, // kernel name prefix
|
||||
float16, // input type
|
||||
float, // output type
|
||||
(pb/2 + pb%2), // innermost loop iterations
|
||||
sh_pack_a, // pack kernel for A
|
||||
sh_pack_b, // pack kernel for B
|
||||
bli_shgemm_power10_mma_8x16, // microkernel function name
|
||||
2, // K_MMA
|
||||
8, // MR
|
||||
16, // NR
|
||||
384, // MC
|
||||
3328, // KC
|
||||
4096, // NC
|
||||
0, // A_ALIGN
|
||||
0 // B_ALIGN
|
||||
);
|
||||
|
||||
GENERIC_GEMM(
|
||||
i16, // kernel name prefix
|
||||
int16_t, // input type
|
||||
int, // output type
|
||||
(pb/2 + pb%2), // innermost loop iterations
|
||||
i16_pack_a, // pack kernel for A
|
||||
i16_pack_b, // pack kernel for B
|
||||
bli_i16gemm_power10_mma_8x16, // microkernel function name
|
||||
2, // K_MMA
|
||||
8, // MR
|
||||
16, // NR
|
||||
384, // MC
|
||||
3328, // KC
|
||||
4096, // NC
|
||||
0, // A_ALIGN
|
||||
0 // B_ALIGN
|
||||
);
|
||||
|
||||
GENERIC_GEMM(
|
||||
i8, // kernel name prefix
|
||||
int8_t, // input type
|
||||
int, // output type
|
||||
(pb/4 + (pb%4>0)), // innermost loop iterations
|
||||
i8_pack_a, // pack kernel for A
|
||||
i8_pack_b, // pack kernel for B
|
||||
bli_i8gemm_power10_mma_8x16, // microkernel function name
|
||||
4, // K_MMA
|
||||
8, // MR
|
||||
16, // NR
|
||||
384, // MC
|
||||
6656, // KC
|
||||
4096, // NC
|
||||
0, // A_ALIGN
|
||||
0 // B_ALIGN
|
||||
);
|
||||
|
||||
GENERIC_GEMM(
|
||||
i4, // kernel name prefix
|
||||
nibbles, // input type
|
||||
int, // output type
|
||||
(pb/8 + (pb%8>0)), // innermost loop iterations
|
||||
i4_pack_a, // pack kernel for A
|
||||
i4_pack_b, // pack kernel for B
|
||||
bli_i4gemm_power10_mma_8x16, // microkernel function name
|
||||
8, // K_MMA
|
||||
8, // MR
|
||||
16, // NR
|
||||
384, // MC
|
||||
6656, // KC
|
||||
4096, // NC
|
||||
0, // A_ALIGN
|
||||
0 // B_ALIGN
|
||||
);
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
// This file contains the API points for the low precision POWER10 GEMM kernels
|
||||
|
||||
#include "generic_gemm.h"
|
||||
#include "gemm_api.h"
|
||||
|
||||
#define GEMM_FUNC(ch, DTYPE_IN, DTYPE_OUT, A_ALIGNMENT, B_ALIGNMENT, MR, NR, MC, KC, NC) \
|
||||
\
|
||||
void GEMM_FUNC_NAME(ch) \
|
||||
( \
|
||||
trans_t transa, \
|
||||
trans_t transb, \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
DTYPE_OUT* alpha, \
|
||||
DTYPE_IN* a, inc_t rsa, inc_t csa, \
|
||||
DTYPE_IN* b, inc_t rsb, inc_t csb, \
|
||||
DTYPE_OUT* beta, \
|
||||
DTYPE_OUT* c, inc_t rsc, inc_t csc \
|
||||
) \
|
||||
{ \
|
||||
\
|
||||
if (transa != BLIS_NO_TRANSPOSE || transb != BLIS_NO_TRANSPOSE) { \
|
||||
printf("Transpose functionality not implemented yet.\n"); \
|
||||
} \
|
||||
\
|
||||
GEMM_PASTEMAC(ch) \
|
||||
( \
|
||||
MR, NR, MC, KC, NC, \
|
||||
m, n, k, \
|
||||
a, rsa, csa, A_ALIGNMENT, \
|
||||
b, rsb, csb, B_ALIGNMENT, \
|
||||
c, rsc, csc, \
|
||||
alpha, beta \
|
||||
); \
|
||||
} \
|
||||
|
||||
// ch dt_in dt_out MR NR MC KC NC
|
||||
GEMM_FUNC( sb, bfloat16, float, 0, 0, 8, 16, 1664, 1026, 4096);
|
||||
GEMM_FUNC( sh, float16, float, 0, 0, 8, 16, 1664, 1026, 4096);
|
||||
GEMM_FUNC( i16, int16_t, int32_t, 0, 0, 8, 16, 1664, 1026, 4096);
|
||||
GEMM_FUNC( i8, int8_t, int32_t, 0, 0, 8, 16, 1664, 1026, 4096);
|
||||
GEMM_FUNC( i4, nibbles, int32_t, 0, 0, 8, 16, 1664, 1026, 4096);
|
||||
@@ -1,889 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
// Templates for different packing routine
|
||||
|
||||
#include "gemm_pack.h"
|
||||
|
||||
/*
|
||||
|
||||
Details on bit16_dt vector data structure
|
||||
|
||||
Vector X = [ X[0,0] X[0,1] X[1,0] X[1,1] X[2,0] X[2,1] X[3,0] X[3,1] ]
|
||||
Vector Y = [ Y[0,0] Y[0,1] Y[1,0] Y[1,1] Y[2,0] Y[2,1] Y[3,0] Y[3,1] ]
|
||||
|
||||
These bit16_dt vectors represent a 4x2 matrix. Hence, in matrix form it
|
||||
looks like the following:
|
||||
|
||||
X = [ X[0,0] X[0,1]
|
||||
X[1,0] X[1,1]
|
||||
X[2,0] X[2,1]
|
||||
X[3,0] X[3,1] ]
|
||||
|
||||
The outer product instruction: xvbf16ger2 (bfloat16 outer product)
|
||||
|
||||
Syntax:
|
||||
|
||||
xvbf16ger2 ACCUMULATOR A, VECTOR X, VECTOR Y
|
||||
|
||||
Semantics:
|
||||
|
||||
A = X * Y^T
|
||||
|
||||
The generic packing routine would load 8 elements from the same column.
|
||||
This causes an issue since the instruction expects the vector to be a
|
||||
4x2 matrix where the data is packed in contiguous order. Thus, we must make
|
||||
a packing routine that will interleave the matrix data. Making it so
|
||||
that when we load the 8 contiguous elements from A, it will represent
|
||||
a 4x2 section of the matrix.
|
||||
|
||||
*/
|
||||
|
||||
#define k_even_apack_16(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + p_idx*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (p_idx+1)*cs_a ];
|
||||
|
||||
#define k_odd_apack_16(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-1)*cs_a ]; \
|
||||
memset(adest, 0, 2); \
|
||||
adest++;
|
||||
|
||||
#define pad_macro_16(dest_matrix) \
|
||||
memset(dest_matrix, 0, 4); \
|
||||
dest_matrix+=2;
|
||||
|
||||
#define BIT16_PACK_A(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, A) \
|
||||
( \
|
||||
dim_t MR, \
|
||||
int m, int k, \
|
||||
DTYPE_IN* ap, int rs_a, int cs_a, \
|
||||
DTYPE_IN* apack \
|
||||
) \
|
||||
{ \
|
||||
int k_odd = k%2; \
|
||||
int p_idx; \
|
||||
\
|
||||
DTYPE_IN* adest = apack; \
|
||||
for (int i=0; i<m; i+=MR) \
|
||||
{ \
|
||||
int ib = bli_min(MR, m-i); \
|
||||
if (ib == MR) /* Full size column height */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for (int p=0; p<(k/2); p++) \
|
||||
{ \
|
||||
k_even_apack_16(0); \
|
||||
k_even_apack_16(1); \
|
||||
k_even_apack_16(2); \
|
||||
k_even_apack_16(3); \
|
||||
k_even_apack_16(4); \
|
||||
k_even_apack_16(5); \
|
||||
k_even_apack_16(6); \
|
||||
k_even_apack_16(7); \
|
||||
p_idx += 2; \
|
||||
} \
|
||||
\
|
||||
/* In the case that k is odd, we must pad with 0s */ \
|
||||
if(k_odd) \
|
||||
{ \
|
||||
k_odd_apack_16(0); \
|
||||
k_odd_apack_16(1); \
|
||||
k_odd_apack_16(2); \
|
||||
k_odd_apack_16(3); \
|
||||
k_odd_apack_16(4); \
|
||||
k_odd_apack_16(5); \
|
||||
k_odd_apack_16(6); \
|
||||
k_odd_apack_16(7); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
else /* Not full size, pad with zeros */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for (int p=0; p<(k/2); p++) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_even_apack_16(ir); \
|
||||
} \
|
||||
for (int ir=ib; ir<MR; ir++) \
|
||||
{ \
|
||||
pad_macro_16(adest); \
|
||||
} \
|
||||
p_idx += 2; \
|
||||
} \
|
||||
\
|
||||
if(k_odd) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_odd_apack_16(ir); \
|
||||
} \
|
||||
for (int ir=ib; ir<MR; ir++) \
|
||||
{ \
|
||||
pad_macro_16(adest); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
#define k_even_bpack_16(jr) \
|
||||
*bdest++ = bp[ p_idx*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (p_idx+1)*rs_b + (j+jr)*cs_b ]; \
|
||||
|
||||
#define k_odd_bpack_16(jr) \
|
||||
*bdest++ = bp[ (k-1)*rs_b + (j+jr)*cs_b ]; \
|
||||
memset(bdest, 0, 2); \
|
||||
bdest++; \
|
||||
|
||||
#define BIT16_PACK_B(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, B) \
|
||||
( \
|
||||
dim_t NR, \
|
||||
int k, int n, \
|
||||
DTYPE_IN* bp, int rs_b, int cs_b, \
|
||||
DTYPE_IN* bpack \
|
||||
) \
|
||||
{ \
|
||||
\
|
||||
int k_odd = k%2; \
|
||||
int p_idx; \
|
||||
\
|
||||
DTYPE_IN* bdest = bpack; \
|
||||
\
|
||||
for( int j=0; j<n; j += NR ) \
|
||||
{ \
|
||||
int jb = bli_min(NR, n-j); \
|
||||
\
|
||||
if ( jb == NR ) /* Full column width micro-panel.*/ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for ( int p=0; p<(k/2); p++ ) \
|
||||
{ \
|
||||
k_even_bpack_16(0); \
|
||||
k_even_bpack_16(1); \
|
||||
k_even_bpack_16(2); \
|
||||
k_even_bpack_16(3); \
|
||||
k_even_bpack_16(4); \
|
||||
k_even_bpack_16(5); \
|
||||
k_even_bpack_16(6); \
|
||||
k_even_bpack_16(7); \
|
||||
k_even_bpack_16(8); \
|
||||
k_even_bpack_16(9); \
|
||||
k_even_bpack_16(10); \
|
||||
k_even_bpack_16(11); \
|
||||
k_even_bpack_16(12); \
|
||||
k_even_bpack_16(13); \
|
||||
k_even_bpack_16(14); \
|
||||
k_even_bpack_16(15); \
|
||||
p_idx += 2; \
|
||||
} \
|
||||
\
|
||||
/* In the case that k is odd, we must pad with 0s */ \
|
||||
if(k_odd) \
|
||||
{ \
|
||||
k_odd_bpack_16(0); \
|
||||
k_odd_bpack_16(1); \
|
||||
k_odd_bpack_16(2); \
|
||||
k_odd_bpack_16(3); \
|
||||
k_odd_bpack_16(4); \
|
||||
k_odd_bpack_16(5); \
|
||||
k_odd_bpack_16(6); \
|
||||
k_odd_bpack_16(7); \
|
||||
k_odd_bpack_16(8); \
|
||||
k_odd_bpack_16(9); \
|
||||
k_odd_bpack_16(10); \
|
||||
k_odd_bpack_16(11); \
|
||||
k_odd_bpack_16(12); \
|
||||
k_odd_bpack_16(13); \
|
||||
k_odd_bpack_16(14); \
|
||||
k_odd_bpack_16(15); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
else /* Not a full row size micro-panel. We pad with zeroes. */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for ( int p=0; p<(k/2); p++ ) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_even_bpack_16(jr); \
|
||||
} \
|
||||
for ( int jr=jb; jr<NR; jr++ ) \
|
||||
{ \
|
||||
pad_macro_16(bdest); \
|
||||
} \
|
||||
p_idx += 2; \
|
||||
} \
|
||||
\
|
||||
if(k_odd) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_odd_bpack_16(jr); \
|
||||
} \
|
||||
for ( int jr=jb; jr<NR; jr++ ) \
|
||||
{ \
|
||||
pad_macro_16(bdest); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
};
|
||||
|
||||
|
||||
|
||||
/* 8 bit packing routines */
|
||||
|
||||
#define k_even_apack_8(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + p_idx*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (p_idx+1)*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (p_idx+2)*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (p_idx+3)*cs_a ];
|
||||
|
||||
#define k_left3_apack_8(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-3)*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-2)*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-1)*cs_a ]; \
|
||||
memset(adest, 0, 1); \
|
||||
adest++;
|
||||
|
||||
#define k_left2_apack_8(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-2)*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-1)*cs_a ]; \
|
||||
memset(adest, 0, 2); \
|
||||
adest += 2;
|
||||
|
||||
#define k_left1_apack_8(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-1)*cs_a ]; \
|
||||
memset(adest, 0, 3); \
|
||||
adest += 3;
|
||||
|
||||
#define pad_macro_8(dest_matrix) \
|
||||
memset(dest_matrix, 0, 4); \
|
||||
dest_matrix += 4;
|
||||
|
||||
|
||||
#define BIT8_PACK_A(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, A) \
|
||||
( \
|
||||
dim_t MR, \
|
||||
int m, int k, \
|
||||
DTYPE_IN* ap, int rs_a, int cs_a, \
|
||||
DTYPE_IN* apack \
|
||||
) \
|
||||
{ \
|
||||
int k_left = k%4; \
|
||||
int k_iter = k/4; \
|
||||
int p_idx; \
|
||||
\
|
||||
DTYPE_IN* adest = apack; \
|
||||
\
|
||||
/* Each panel must be packed in this format */ \
|
||||
for (int i=0; i<m; i+=MR) \
|
||||
{ \
|
||||
int ib = bli_min(MR, m-i); \
|
||||
\
|
||||
if (ib == MR) /* Full size column height */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for (int p=0; p<k_iter; p++) \
|
||||
{ \
|
||||
k_even_apack_8(0); \
|
||||
k_even_apack_8(1); \
|
||||
k_even_apack_8(2); \
|
||||
k_even_apack_8(3); \
|
||||
k_even_apack_8(4); \
|
||||
k_even_apack_8(5); \
|
||||
k_even_apack_8(6); \
|
||||
k_even_apack_8(7); \
|
||||
p_idx += 4; \
|
||||
} \
|
||||
\
|
||||
/* In the case that k is odd, we must pad with 0s */ \
|
||||
if(k_left==3) \
|
||||
{ \
|
||||
k_left3_apack_8(0); \
|
||||
k_left3_apack_8(1); \
|
||||
k_left3_apack_8(2); \
|
||||
k_left3_apack_8(3); \
|
||||
k_left3_apack_8(4); \
|
||||
k_left3_apack_8(5); \
|
||||
k_left3_apack_8(6); \
|
||||
k_left3_apack_8(7); \
|
||||
} \
|
||||
else if(k_left==2) \
|
||||
{ \
|
||||
k_left2_apack_8(0); \
|
||||
k_left2_apack_8(1); \
|
||||
k_left2_apack_8(2); \
|
||||
k_left2_apack_8(3); \
|
||||
k_left2_apack_8(4); \
|
||||
k_left2_apack_8(5); \
|
||||
k_left2_apack_8(6); \
|
||||
k_left2_apack_8(7); \
|
||||
} \
|
||||
else if(k_left==1) \
|
||||
{ \
|
||||
k_left1_apack_8(0); \
|
||||
k_left1_apack_8(1); \
|
||||
k_left1_apack_8(2); \
|
||||
k_left1_apack_8(3); \
|
||||
k_left1_apack_8(4); \
|
||||
k_left1_apack_8(5); \
|
||||
k_left1_apack_8(6); \
|
||||
k_left1_apack_8(7); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
else /* Not full size, pad with zeros */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for (int p=0; p<k_iter; p++) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_even_apack_8(ir); \
|
||||
} \
|
||||
for (int ir=ib; ir<MR; ir++) \
|
||||
{ \
|
||||
pad_macro_8(adest); \
|
||||
} \
|
||||
p_idx += 4; \
|
||||
} \
|
||||
\
|
||||
if(k_left==3) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_left3_apack_8(ir); \
|
||||
} \
|
||||
} \
|
||||
else if(k_left==2) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_left2_apack_8(ir); \
|
||||
} \
|
||||
} \
|
||||
else if(k_left==1) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_left1_apack_8(ir); \
|
||||
} \
|
||||
} \
|
||||
if(k_left!=0) \
|
||||
{ \
|
||||
for (int ir=ib; ir<MR; ir++) { \
|
||||
pad_macro_8(adest); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
#define k_even_bpack_8(jr) \
|
||||
*bdest++ = bp[ p_idx*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (p_idx+1)*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (p_idx+2)*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (p_idx+3)*rs_b + (j+jr)*cs_b ];
|
||||
|
||||
#define k_left3_bpack_8(jr) \
|
||||
*bdest++ = bp[ (k-3)*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (k-2)*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (k-1)*rs_b + (j+jr)*cs_b ]; \
|
||||
memset(bdest, 0, 1); \
|
||||
bdest++;
|
||||
|
||||
#define k_left2_bpack_8(jr) \
|
||||
*bdest++ = bp[ (k-2)*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (k-1)*rs_b + (j+jr)*cs_b ]; \
|
||||
memset(bdest, 0, 2); \
|
||||
bdest+=2;
|
||||
|
||||
#define k_left1_bpack_8(jr) \
|
||||
*bdest++ = bp[ (k-1)*rs_b + (j+jr)*cs_b ]; \
|
||||
memset(bdest, 0, 3); \
|
||||
bdest+=3;
|
||||
|
||||
|
||||
#define BIT8_PACK_B(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, B) \
|
||||
( \
|
||||
dim_t NR, \
|
||||
int k, int n, \
|
||||
DTYPE_IN* bp, int rs_b, int cs_b, \
|
||||
DTYPE_IN* bpack \
|
||||
) \
|
||||
{ \
|
||||
int k_left = k%4; \
|
||||
int k_iter = k/4; \
|
||||
int p_idx; \
|
||||
\
|
||||
DTYPE_IN* bdest = bpack; \
|
||||
\
|
||||
for( int j=0; j<n; j += NR ) \
|
||||
{ \
|
||||
int jb = bli_min(NR, n-j); \
|
||||
\
|
||||
if ( jb == NR ) /* Full column width micro-panel.*/ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for ( int p=0; p<k_iter; p++ ) \
|
||||
{ \
|
||||
k_even_bpack_8(0); \
|
||||
k_even_bpack_8(1); \
|
||||
k_even_bpack_8(2); \
|
||||
k_even_bpack_8(3); \
|
||||
k_even_bpack_8(4); \
|
||||
k_even_bpack_8(5); \
|
||||
k_even_bpack_8(6); \
|
||||
k_even_bpack_8(7); \
|
||||
k_even_bpack_8(8); \
|
||||
k_even_bpack_8(9); \
|
||||
k_even_bpack_8(10); \
|
||||
k_even_bpack_8(11); \
|
||||
k_even_bpack_8(12); \
|
||||
k_even_bpack_8(13); \
|
||||
k_even_bpack_8(14); \
|
||||
k_even_bpack_8(15); \
|
||||
p_idx += 4; \
|
||||
} \
|
||||
\
|
||||
if(k_left==3) \
|
||||
{ \
|
||||
k_left3_bpack_8(0); \
|
||||
k_left3_bpack_8(1); \
|
||||
k_left3_bpack_8(2); \
|
||||
k_left3_bpack_8(3); \
|
||||
k_left3_bpack_8(4); \
|
||||
k_left3_bpack_8(5); \
|
||||
k_left3_bpack_8(6); \
|
||||
k_left3_bpack_8(7); \
|
||||
k_left3_bpack_8(8); \
|
||||
k_left3_bpack_8(9); \
|
||||
k_left3_bpack_8(10); \
|
||||
k_left3_bpack_8(11); \
|
||||
k_left3_bpack_8(12); \
|
||||
k_left3_bpack_8(13); \
|
||||
k_left3_bpack_8(14); \
|
||||
k_left3_bpack_8(15); \
|
||||
} \
|
||||
else if(k_left==2) \
|
||||
{ \
|
||||
k_left2_bpack_8(0); \
|
||||
k_left2_bpack_8(1); \
|
||||
k_left2_bpack_8(2); \
|
||||
k_left2_bpack_8(3); \
|
||||
k_left2_bpack_8(4); \
|
||||
k_left2_bpack_8(5); \
|
||||
k_left2_bpack_8(6); \
|
||||
k_left2_bpack_8(7); \
|
||||
k_left2_bpack_8(8); \
|
||||
k_left2_bpack_8(9); \
|
||||
k_left2_bpack_8(10); \
|
||||
k_left2_bpack_8(11); \
|
||||
k_left2_bpack_8(12); \
|
||||
k_left2_bpack_8(13); \
|
||||
k_left2_bpack_8(14); \
|
||||
k_left2_bpack_8(15); \
|
||||
} \
|
||||
else if(k_left==1) \
|
||||
{ \
|
||||
k_left1_bpack_8(0); \
|
||||
k_left1_bpack_8(1); \
|
||||
k_left1_bpack_8(2); \
|
||||
k_left1_bpack_8(3); \
|
||||
k_left1_bpack_8(4); \
|
||||
k_left1_bpack_8(5); \
|
||||
k_left1_bpack_8(6); \
|
||||
k_left1_bpack_8(7); \
|
||||
k_left1_bpack_8(8); \
|
||||
k_left1_bpack_8(9); \
|
||||
k_left1_bpack_8(10); \
|
||||
k_left1_bpack_8(11); \
|
||||
k_left1_bpack_8(12); \
|
||||
k_left1_bpack_8(13); \
|
||||
k_left1_bpack_8(14); \
|
||||
k_left1_bpack_8(15); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
else /* Not a full row size micro-panel. We pad with zeroes. */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for ( int p=0; p<k_iter; p++ ) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_even_bpack_8(jr); \
|
||||
} \
|
||||
for ( int jr=jb; jr<NR; jr++ ) \
|
||||
{ \
|
||||
pad_macro_8(bdest); \
|
||||
} \
|
||||
p_idx += 4; \
|
||||
} \
|
||||
\
|
||||
if(k_left==3) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_left3_bpack_8(jr); \
|
||||
} \
|
||||
} \
|
||||
else if(k_left==2) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_left2_bpack_8(jr); \
|
||||
} \
|
||||
} \
|
||||
else if(k_left==1) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_left1_bpack_8(jr); \
|
||||
} \
|
||||
} \
|
||||
if (k_left!=0) \
|
||||
{ \
|
||||
for ( int jr=jb; jr<NR; jr++ ) { \
|
||||
pad_macro_8(bdest); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/* Packing Routines */
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
|
||||
Memory is byte-addressed. This results in two options when dealing with
|
||||
int4. Either store 1 int4 value in a byte, or store 2 int4 values in 1
|
||||
byte. The former is wasteful in storage, but it makes for a simpler
|
||||
packing routine. However, we want to not waste any storage if possible.
|
||||
Therefore I went with the latter when designing my int4 kernel.
|
||||
|
||||
The int4 outerproduct instruction expects a 4x8 matrix in row major order
|
||||
to be loaded into the vector. In order to achieve this 4x8 row major
|
||||
matrix, we pack as many 4x8 panels from the src matrix into the pack matrix.
|
||||
|
||||
To illustrate how my packing routine works:
|
||||
|
||||
x0 x1 x2 x3 x4 x5 x6 x7
|
||||
x9 x10 x11 x12 x13 x14 x15 x16
|
||||
x17 x18 x19 x20 x21 x22 x23 x24
|
||||
x25 x26 x27 x28 x29 x30 x31 x32
|
||||
|
||||
Assume we have a 4x8 matrix that is stored in column major order. Also
|
||||
since we are dealing with int4 values, the values are stored as pairs
|
||||
within a union struct. i.e. (x0, x9) are stored together in the same struct.
|
||||
|
||||
Therefore in order to get the desired 4x8 row major matrix, we must go
|
||||
through the first row of structs and grab the first int4 value and insert
|
||||
it into the appropriate spot in the pack matrix. This means that after
|
||||
packing, (x0, x1) will be stored together in the same struct.
|
||||
|
||||
This process then repeats until the entire src matrix is packed in these
|
||||
4x8 row major matrix panels.
|
||||
|
||||
To handle edge cases, the packing routine will fill in zeros where it is
|
||||
appropriate.
|
||||
|
||||
*/
|
||||
|
||||
#include "i4_macros.h"
|
||||
|
||||
#define BIT4_PACK_A(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, A) \
|
||||
( \
|
||||
dim_t MR, \
|
||||
int m, int k, \
|
||||
DTYPE_IN* ap, int rs_a, int cs_a, \
|
||||
DTYPE_IN* apack \
|
||||
) \
|
||||
{ \
|
||||
int p_idx, k_left, k_iter; \
|
||||
DTYPE_IN* adest = apack; \
|
||||
\
|
||||
k_left = k%8; \
|
||||
k_iter = k/8; \
|
||||
\
|
||||
int i = 0; /* i is used for byte addressing */ \
|
||||
for(int int4_i=0; int4_i<m; int4_i+=MR) { /* pack panels */ \
|
||||
\
|
||||
int ib = bli_min(MR, m-int4_i); \
|
||||
p_idx = 0; \
|
||||
\
|
||||
if (ib == MR) { /* full size */ \
|
||||
for (int p=0; p<k_iter; p++) { \
|
||||
col_m_order_1(adest, ap, (i+0), rs_a, cs_a); \
|
||||
col_m_order_2(adest, ap, (i+0), rs_a, cs_a); \
|
||||
col_m_order_1(adest, ap, (i+1), rs_a, cs_a); \
|
||||
col_m_order_2(adest, ap, (i+1), rs_a, cs_a); \
|
||||
col_m_order_1(adest, ap, (i+2), rs_a, cs_a); \
|
||||
col_m_order_2(adest, ap, (i+2), rs_a, cs_a); \
|
||||
col_m_order_1(adest, ap, (i+3), rs_a, cs_a); \
|
||||
col_m_order_2(adest, ap, (i+3), rs_a, cs_a); \
|
||||
p_idx += 8; \
|
||||
} \
|
||||
\
|
||||
/* handle edge cases if there are any */ \
|
||||
if(k_left == 7) { \
|
||||
apad_col_kleft7(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 6) { \
|
||||
apad_col_kleft6(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 5) { \
|
||||
apad_col_kleft5(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 4) { \
|
||||
apad_col_kleft4(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 3) { \
|
||||
apad_col_kleft3(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 2) { \
|
||||
apad_col_kleft2(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 1) { \
|
||||
apad_col_kleft1(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
else { /* not full size */ \
|
||||
for (int p=0; p<k_iter; p++) { \
|
||||
for (int ir=0; ir<ib; ir++) { \
|
||||
if (ir%2==0) { \
|
||||
col_m_order_1(adest, ap, (i+ir/2), rs_a, cs_a); \
|
||||
} \
|
||||
else { \
|
||||
col_m_order_2(adest, ap, (i+ir/2), rs_a, cs_a); \
|
||||
} \
|
||||
} \
|
||||
for (int ir=ib; ir<MR; ir++) { \
|
||||
zero_out_dest(adest); \
|
||||
} \
|
||||
p_idx += 8; \
|
||||
} \
|
||||
\
|
||||
/* handle edge cases if there are any */ \
|
||||
if(k_left == 7) { \
|
||||
edge7(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 6) { \
|
||||
edge6(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 5) { \
|
||||
edge5(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 4) { \
|
||||
edge4(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 3) { \
|
||||
edge3(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 2) { \
|
||||
edge2(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 1) { \
|
||||
edge1(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
\
|
||||
/* fill in zeros when an edge case occurs */ \
|
||||
if(k_left!=0) \
|
||||
{ \
|
||||
for (int ir=ib; ir<MR; ir++) \
|
||||
zero_out_dest(adest); \
|
||||
} \
|
||||
} \
|
||||
i += (MR/2); \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
#define BIT4_PACK_B(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, B) \
|
||||
( \
|
||||
dim_t NR, \
|
||||
int k, int n, \
|
||||
DTYPE_IN* bp, int rs_b, int cs_b, \
|
||||
DTYPE_IN* bpack \
|
||||
) \
|
||||
{ \
|
||||
\
|
||||
int p_idx, k_left, k_iter; \
|
||||
DTYPE_IN* bdest = bpack; \
|
||||
\
|
||||
k_left = k%8; \
|
||||
k_iter = k/8; \
|
||||
\
|
||||
int j = 0; \
|
||||
\
|
||||
for(int int4_j=0; int4_j<n; int4_j+=NR) { /* pack panels */ \
|
||||
int jb = bli_min(NR, n-int4_j); \
|
||||
\
|
||||
p_idx = 0; \
|
||||
if (jb == NR) { /* full size */ \
|
||||
for (int p=0; p<k_iter; p++) { \
|
||||
col_m_order_1(bdest, bp, (j+0), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+0), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+1), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+1), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+2), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+2), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+3), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+3), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+4), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+4), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+5), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+5), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+6), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+6), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+7), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+7), cs_b, rs_b); \
|
||||
p_idx += 8; \
|
||||
} \
|
||||
\
|
||||
/* handle edge cases if there are any */ \
|
||||
if(k_left == 7) { \
|
||||
bpad_col_kleft7(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 6) { \
|
||||
bpad_col_kleft6(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 5) { \
|
||||
bpad_col_kleft5(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 4) { \
|
||||
bpad_col_kleft4(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 3) { \
|
||||
bpad_col_kleft3(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 2) { \
|
||||
bpad_col_kleft2(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 1) { \
|
||||
bpad_col_kleft1(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
} \
|
||||
else { /* not full size */ \
|
||||
for (int p=0; p<k_iter; p++) { \
|
||||
for (int jr=0; jr<jb; jr++) { \
|
||||
if (jr%2==0) { \
|
||||
col_m_order_1(bdest, bp, (j+jr/2), cs_b, rs_b); \
|
||||
} \
|
||||
else { \
|
||||
col_m_order_2(bdest, bp, (j+jr/2), cs_b, rs_b); \
|
||||
} \
|
||||
} \
|
||||
for (int jr=jb; jr<NR; jr++) { \
|
||||
zero_out_dest(bdest); \
|
||||
} \
|
||||
p_idx += 8; \
|
||||
} \
|
||||
\
|
||||
/* handle edge cases if there are any */ \
|
||||
if(k_left == 7) { \
|
||||
edge7(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 6) { \
|
||||
edge6(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 5) { \
|
||||
edge5(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 4) { \
|
||||
edge4(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 3) { \
|
||||
edge3(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 2) { \
|
||||
edge2(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 1) { \
|
||||
edge1(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
\
|
||||
/* fill in zeros when an edge case occurs */ \
|
||||
if(k_left!=0) \
|
||||
{ \
|
||||
for (int ir=jb; ir<NR; ir++) \
|
||||
zero_out_dest(bdest); \
|
||||
} \
|
||||
} \
|
||||
j += (NR/2); \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
|
||||
#define BIT16_PACK_ROUTINES(ch, DTYPE_IN) \
|
||||
BIT16_PACK_A(ch, DTYPE_IN); \
|
||||
BIT16_PACK_B(ch, DTYPE_IN);
|
||||
|
||||
#define BIT8_PACK_ROUTINES(ch, DTYPE_IN) \
|
||||
BIT8_PACK_A(ch, DTYPE_IN); \
|
||||
BIT8_PACK_B(ch, DTYPE_IN);
|
||||
|
||||
#define BIT4_PACK_ROUTINES(ch, DTYPE_IN) \
|
||||
BIT4_PACK_A(ch, DTYPE_IN); \
|
||||
BIT4_PACK_B(ch, DTYPE_IN);
|
||||
|
||||
BIT16_PACK_ROUTINES(sb, bfloat16);
|
||||
BIT16_PACK_ROUTINES(i16, int16_t);
|
||||
BIT16_PACK_ROUTINES(sh, float16);
|
||||
|
||||
BIT8_PACK_ROUTINES(i8, int8_t);
|
||||
|
||||
BIT4_PACK_ROUTINES(i4, nibbles);
|
||||
@@ -1,64 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
// Templates for packing routines prototypes
|
||||
|
||||
#include "bli_sandbox.h"
|
||||
|
||||
#define PACK_FUNC_NAME_(ch, mat) ch ## _pack ## mat
|
||||
#define PACK_FUNC_NAME(ch, mat) PACK_FUNC_NAME_(ch, mat)
|
||||
|
||||
#define PACK_MACRO_PROTO(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, A) \
|
||||
( \
|
||||
dim_t MR, \
|
||||
int m, int k, \
|
||||
DTYPE_IN* ap, int rs_a, int cs_a, \
|
||||
DTYPE_IN* apack \
|
||||
); \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, B) \
|
||||
( \
|
||||
dim_t NR, \
|
||||
int k, int n, \
|
||||
DTYPE_IN* bp, int rs_b, int cs_b, \
|
||||
DTYPE_IN* bpack \
|
||||
);
|
||||
|
||||
PACK_MACRO_PROTO(sb, bfloat16)
|
||||
PACK_MACRO_PROTO(sh, float16)
|
||||
PACK_MACRO_PROTO(i16, int16_t)
|
||||
PACK_MACRO_PROTO(i8, int8_t)
|
||||
PACK_MACRO_PROTO(i4, nibbles)
|
||||
@@ -32,11 +32,11 @@
|
||||
|
||||
*/
|
||||
|
||||
// Prototypes and template for the low precision POWER10 GEMM API
|
||||
|
||||
// BLIS GEMM function naming scheme
|
||||
#define GEMM_FUNC_NAME_(ch) bli_ ## ch ## gemm
|
||||
#define GEMM_FUNC_NAME(ch) GEMM_FUNC_NAME_(ch)
|
||||
|
||||
// BLIS GEMM function prototype macro
|
||||
#define GEMM_FUNC_PROT(DTYPE_IN, DTYPE_OUT, ch) \
|
||||
void GEMM_FUNC_NAME(ch) \
|
||||
( \
|
||||
@@ -51,3 +51,26 @@
|
||||
DTYPE_OUT* beta, \
|
||||
DTYPE_OUT* c, inc_t rsc, inc_t csc \
|
||||
)
|
||||
|
||||
// Pack routine naming scheme
|
||||
#define PACK_FUNC_NAME_(ch, mat) ch ## _pack_ ## mat
|
||||
#define PACK_FUNC_NAME(ch, mat) PACK_FUNC_NAME_(ch, mat)
|
||||
|
||||
// Pack routine prototype
|
||||
#define PACK_MACRO_PROTO(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, a) \
|
||||
( \
|
||||
dim_t MR, \
|
||||
int m, int k, \
|
||||
DTYPE_IN* ap, int rs_a, int cs_a, \
|
||||
DTYPE_IN* apack \
|
||||
); \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, b) \
|
||||
( \
|
||||
dim_t NR, \
|
||||
int k, int n, \
|
||||
DTYPE_IN* bp, int rs_b, int cs_b, \
|
||||
DTYPE_IN* bpack \
|
||||
);
|
||||
@@ -32,66 +32,96 @@
|
||||
|
||||
*/
|
||||
|
||||
// Using the GENERIC_GEMM template, create GEMM functions for each datatype
|
||||
#include "blis.h"
|
||||
|
||||
#include "generic_gemm.h"
|
||||
#include "gemm_pack.h"
|
||||
/*
|
||||
Macro function template for creating BLIS GEMM kernels using the Goto method.
|
||||
|
||||
#define GENERIC_GEMM(ch, DTYPE_IN, DTYPE_OUT, NEW_PB, MULT, UK_FUNC) \
|
||||
This GEMM template assumes that the matrices are both not transposed.
|
||||
|
||||
ch - kernel name prefix
|
||||
DTYPE_IN, DTYPE_OUT - datatypes of the input and output operands respectively
|
||||
NEW_PB - number of iterations of the innermost loop
|
||||
PACK_A, PACK_B - pack kernels names
|
||||
MICROKERNEL - microkernel function name
|
||||
K_MMA - number of outer products performed by an instruction
|
||||
MR, NR, MC, KC, NC - Cache blocking parameters
|
||||
B_ALIGN, A_ALIGN - Extra byte alignment for the pack matrix buffers
|
||||
*/
|
||||
#define GENERIC_GEMM( \
|
||||
ch, \
|
||||
DTYPE_IN, \
|
||||
DTYPE_OUT, \
|
||||
NEW_PB, \
|
||||
PACK_A, \
|
||||
PACK_B, \
|
||||
MICROKERNEL, \
|
||||
K_MMA, \
|
||||
MR, \
|
||||
NR, \
|
||||
MC, \
|
||||
KC, \
|
||||
NC, \
|
||||
B_ALIGN, \
|
||||
A_ALIGN \
|
||||
) \
|
||||
\
|
||||
void GEMM_PASTEMAC(ch) \
|
||||
void GEMM_FUNC_NAME(ch) \
|
||||
( \
|
||||
dim_t MR, dim_t NR, dim_t KC, dim_t NC, dim_t MC, \
|
||||
int m, int n, int k, \
|
||||
DTYPE_IN* restrict A, int rs_a, int cs_a, int A_align, \
|
||||
DTYPE_IN* restrict B, int rs_b, int cs_b, int B_align, \
|
||||
DTYPE_OUT* restrict C, int rs_c, int cs_c, \
|
||||
DTYPE_OUT* alpha, DTYPE_OUT* beta \
|
||||
trans_t transa, \
|
||||
trans_t transb, \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
DTYPE_OUT* alpha, \
|
||||
DTYPE_IN* a, inc_t rsa, inc_t csa, \
|
||||
DTYPE_IN* b, inc_t rsb, inc_t csb, \
|
||||
DTYPE_OUT* beta, \
|
||||
DTYPE_OUT* c, inc_t rsc, inc_t csc \
|
||||
) \
|
||||
{ \
|
||||
DTYPE_OUT zero = 0.0; \
|
||||
DTYPE_OUT beta_ = *beta; \
|
||||
\
|
||||
DTYPE_IN * restrict btilde_sys = ( DTYPE_IN *) aligned_alloc( P10_PG_SIZE, B_align + KC * NC * sizeof( DTYPE_IN ) ); \
|
||||
DTYPE_IN * restrict atilde_sys = ( DTYPE_IN *) aligned_alloc( P10_PG_SIZE, A_align + MC * KC * sizeof( DTYPE_IN ) ); \
|
||||
DTYPE_IN * restrict btilde_sys = ( DTYPE_IN *) aligned_alloc( P10_PG_SIZE, B_ALIGN + KC * NC * sizeof( DTYPE_IN ) ); \
|
||||
DTYPE_IN * restrict atilde_sys = ( DTYPE_IN *) aligned_alloc( P10_PG_SIZE, A_ALIGN + MC * KC * sizeof( DTYPE_IN ) ); \
|
||||
\
|
||||
DTYPE_IN * restrict btilde_usr = ( DTYPE_IN *)((char *)btilde_sys + B_align); \
|
||||
DTYPE_IN * restrict atilde_usr = ( DTYPE_IN *)((char *)atilde_sys + A_align); \
|
||||
DTYPE_IN * restrict btilde_usr = ( DTYPE_IN *)((char *)btilde_sys + B_ALIGN); \
|
||||
DTYPE_IN * restrict atilde_usr = ( DTYPE_IN *)((char *)atilde_sys + A_ALIGN); \
|
||||
\
|
||||
const int rstep_c = MC*rs_c; \
|
||||
const int cstep_c = NC*cs_c; \
|
||||
const int rstep_c = MC * rsc; \
|
||||
const int cstep_c = NC * csc; \
|
||||
\
|
||||
const int rstep_a = MC*rs_a; \
|
||||
const int cstep_a = KC*cs_a; \
|
||||
const int rstep_a = MC * rsa; \
|
||||
const int cstep_a = KC * csa; \
|
||||
\
|
||||
const int rstep_b = KC*rs_b; \
|
||||
const int cstep_b = NC*cs_b; \
|
||||
const int rstep_b = KC * rsb; \
|
||||
const int cstep_b = NC * csb; \
|
||||
\
|
||||
const int rstep_mt_c = MR*rs_c; \
|
||||
const int cstep_mt_c = NR*cs_c; \
|
||||
const int rstep_mt_c = MR * rsc; \
|
||||
const int cstep_mt_c = NR * csc; \
|
||||
\
|
||||
DTYPE_OUT * restrict cblock = C; \
|
||||
DTYPE_IN * restrict bblock = B; \
|
||||
DTYPE_OUT * restrict cblock = c; \
|
||||
DTYPE_IN * restrict bblock = b; \
|
||||
\
|
||||
DTYPE_OUT tmp_cmicrotile[MR*NR]; \
|
||||
int rs_ct = ( rs_c == 1 ? 1 : NR ); \
|
||||
int cs_ct = ( rs_c == 1 ? MR : 1 ); \
|
||||
int rsct = ( rsc == 1 ? 1 : NR ); \
|
||||
int csct = ( rsc == 1 ? MR : 1 ); \
|
||||
\
|
||||
for ( int jc=0; jc<n; jc+=NC ) \
|
||||
{ \
|
||||
int jb = bli_min( NC, n-jc ); \
|
||||
DTYPE_IN * restrict apanel = A; \
|
||||
DTYPE_IN * restrict apanel = a; \
|
||||
DTYPE_IN * restrict bpanel = bblock; \
|
||||
\
|
||||
for ( int pc=0; pc<k; pc+=KC ) \
|
||||
{ \
|
||||
int pb = bli_min( KC, k-pc ); \
|
||||
ch ## _packB \
|
||||
(NR, pb, jb, bpanel, rs_b, cs_b, btilde_usr); \
|
||||
PACK_B (NR, pb, jb, bpanel, rsb, csb, btilde_usr); \
|
||||
\
|
||||
int new_pb = NEW_PB; \
|
||||
const int a_ps = new_pb * (MULT * MR); \
|
||||
const int b_ps = new_pb * (MULT * NR); \
|
||||
const int a_ps = new_pb * (K_MMA * MR); \
|
||||
const int b_ps = new_pb * (K_MMA * NR); \
|
||||
\
|
||||
DTYPE_OUT * restrict cpanel = cblock; \
|
||||
DTYPE_IN * restrict ablock = apanel; \
|
||||
@@ -100,8 +130,8 @@ void GEMM_PASTEMAC(ch) \
|
||||
{ \
|
||||
int ib = bli_min( MC, m-ic ); \
|
||||
\
|
||||
ch ## _packA \
|
||||
( MR, ib, pb, ablock, rs_a, cs_a, atilde_usr ); \
|
||||
/* pack_a (ib, pb, (uint32_t *) ablock, rsa, csa, (uint32_t *) atilde_usr ); */ \
|
||||
PACK_A (MR, ib, pb, ablock, rsa, csa, atilde_usr ); \
|
||||
\
|
||||
DTYPE_OUT * restrict cmicrotile_col = cpanel; \
|
||||
DTYPE_IN * restrict bmicropanel = btilde_usr; \
|
||||
@@ -117,16 +147,16 @@ void GEMM_PASTEMAC(ch) \
|
||||
int irb = bli_min( MR, ib-ir ); \
|
||||
\
|
||||
if (jrb == NR && irb == MR) \
|
||||
UK_FUNC (new_pb, alpha, amicropanel, bmicropanel, beta, cmicrotile, rs_c, cs_c, NULL, NULL); \
|
||||
MICROKERNEL (new_pb, alpha, amicropanel, bmicropanel, beta, cmicrotile, rsc, csc, NULL, NULL); \
|
||||
else \
|
||||
{ \
|
||||
UK_FUNC (new_pb, alpha, amicropanel, bmicropanel, &zero, tmp_cmicrotile, rs_ct, cs_ct, NULL, NULL); \
|
||||
MICROKERNEL (new_pb, alpha, amicropanel, bmicropanel, &zero, tmp_cmicrotile, rsct, csct, NULL, NULL); \
|
||||
\
|
||||
for (int j=0; j<jrb;j++) \
|
||||
for (int i=0; i<irb;i++) \
|
||||
cmicrotile[i*rs_c + j*cs_c] = \
|
||||
beta_ * cmicrotile[i*rs_c + j*cs_c] + \
|
||||
tmp_cmicrotile[i*rs_ct + j*cs_ct]; \
|
||||
cmicrotile[i*rsc + j*csc] = \
|
||||
beta_ * cmicrotile[i*rsc + j*csc] + \
|
||||
tmp_cmicrotile[i*rsct + j*csct]; \
|
||||
} \
|
||||
amicropanel += a_ps; \
|
||||
cmicrotile += rstep_mt_c; \
|
||||
@@ -143,12 +173,8 @@ void GEMM_PASTEMAC(ch) \
|
||||
cblock += cstep_c; \
|
||||
bblock += cstep_b; \
|
||||
} \
|
||||
\
|
||||
free(btilde_sys); \
|
||||
free(atilde_sys); \
|
||||
}
|
||||
|
||||
GENERIC_GEMM( sb, bfloat16, float, (pb/2 + pb%2), 2, bli_sbgemm_power10_mma_8x16);
|
||||
GENERIC_GEMM(i16, int16_t, int, (pb/2 + pb%2), 2, bli_i16gemm_power10_mma_8x16);
|
||||
GENERIC_GEMM( sh, float16, float, (pb/2 + pb%2), 2, bli_shgemm_power10_mma_8x16);
|
||||
GENERIC_GEMM( i8, int8_t, int, (pb/4 + (pb%4>0)), 4, bli_i8gemm_power10_mma_8x16);
|
||||
GENERIC_GEMM( i4, nibbles, int, (pb/8 + (pb%8>0)), 8, bli_i4gemm_power10_mma_8x16);
|
||||
31
sandbox/power10/p10_testsuite/Makefile
Normal file
31
sandbox/power10/p10_testsuite/Makefile
Normal file
@@ -0,0 +1,31 @@
|
||||
BLIS_PATH := ../../..
|
||||
|
||||
BLIS_INC := $(BLIS_PATH)/include/power10
|
||||
BLIS_LIB := $(BLIS_PATH)/lib/power10/libblis.a
|
||||
|
||||
CC := gcc
|
||||
LINKER := $(CC)
|
||||
|
||||
CFLAGS := -I $(BLIS_INC)
|
||||
LDFLAGS := -lpthread -lm
|
||||
|
||||
OBJS := $(patsubst %.c,%.o, $(wildcard *.c))
|
||||
PERF_OBJS := performance.o
|
||||
COR_OBJS := correctness.o cast_funcs.o
|
||||
|
||||
all: performance correctness
|
||||
|
||||
$(OBJS): %.o: %.c
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
performance: $(PERF_OBJS)
|
||||
$(LINKER) $(PERF_OBJS) $(BLIS_LIB) -o ./gather_perf.x $(LDFLAGS)
|
||||
|
||||
correctness: $(COR_OBJS)
|
||||
$(LINKER) $(COR_OBJS) $(BLIS_LIB) -o ./test_correctness.x $(LDFLAGS)
|
||||
|
||||
csv_clean:
|
||||
rm -rf *.csv
|
||||
|
||||
clean:
|
||||
rm -rf *.x *.o
|
||||
180
sandbox/power10/p10_testsuite/cast_funcs.c
Normal file
180
sandbox/power10/p10_testsuite/cast_funcs.c
Normal file
@@ -0,0 +1,180 @@
|
||||
#include "cast_funcs.h"
|
||||
#include "../bli_sandbox.h"
|
||||
|
||||
// bit map used for casting float to bfloat16
|
||||
typedef union
|
||||
{
|
||||
float v;
|
||||
struct
|
||||
{
|
||||
uint32_t m:23;
|
||||
uint32_t e:8;
|
||||
uint32_t s:1;
|
||||
} bits;
|
||||
} float32_s;
|
||||
|
||||
|
||||
// cast float16 into float
|
||||
float cast_f16_to_f32(float16 val)
|
||||
{
|
||||
uint16_t in = val.v;
|
||||
float out;
|
||||
uint32_t t1;
|
||||
uint32_t t2;
|
||||
uint32_t t3;
|
||||
|
||||
t1 = in & 0x7fff; // Non-sign bits
|
||||
t2 = in & 0x8000; // Sign bit
|
||||
t3 = in & 0x7c00; // Exponent
|
||||
|
||||
t1 <<= 13; // Align mantissa on MSB
|
||||
t2 <<= 16; // Shift sign bit into position
|
||||
|
||||
t1 += 0x38000000; // Adjust bias
|
||||
|
||||
t1 = (t3 == 0 ? 0 : t1); // Denormals-as-zero
|
||||
|
||||
t1 |= t2; // Re-insert sign bit
|
||||
|
||||
*((uint32_t*)&out) = t1;
|
||||
return out;
|
||||
}
|
||||
|
||||
// cast float to float16
|
||||
float16 cast_f32_to_f16(const float in)
|
||||
{
|
||||
float16 f16_out;
|
||||
|
||||
uint32_t inu = *((uint32_t*)&in);
|
||||
uint32_t t1;
|
||||
uint32_t t2;
|
||||
uint32_t t3;
|
||||
|
||||
t1 = inu & 0x7fffffff; // Non-sign bits
|
||||
t2 = inu & 0x80000000; // Sign bit
|
||||
t3 = inu & 0x7f800000; // Exponent
|
||||
|
||||
t1 >>= 13; // Align mantissa on MSB
|
||||
t2 >>= 16; // Shift sign bit into position
|
||||
|
||||
t1 -= 0x1c000; // Adjust bias
|
||||
|
||||
t1 = (t3 < 0x38800000) ? 0 : t1;
|
||||
t1 = (t3 > 0x47000000) ? 0x7bff : t1;
|
||||
t1 = (t3 == 0 ? 0 : t1); // Denormals-as-zero
|
||||
|
||||
t1 |= t2; // Re-insert sign bit
|
||||
|
||||
f16_out.v = t1;
|
||||
return f16_out;
|
||||
}
|
||||
|
||||
|
||||
// cast float to bfloat16
|
||||
bfloat16 cast_f32_to_bf16 (float val)
|
||||
{
|
||||
bfloat16 bf16;
|
||||
float32_s f32;
|
||||
f32.v = val;
|
||||
bf16.bits.s = f32.bits.s;
|
||||
bf16.bits.e = f32.bits.e;
|
||||
bf16.bits.m = f32.bits.m >> 16;
|
||||
return bf16;
|
||||
}
|
||||
|
||||
// cast bfloat16 to float
|
||||
float cast_bf16_to_f32(bfloat16 val)
|
||||
{
|
||||
float32_s f32;
|
||||
f32.bits.s = val.bits.s;
|
||||
f32.bits.e = val.bits.e;
|
||||
f32.bits.m = val.bits.m << 16;
|
||||
return f32.v;
|
||||
}
|
||||
|
||||
// cast a nibbles struct to a float array
|
||||
void cast_i4_to_f32(float *fvals, nibbles vals)
|
||||
{
|
||||
int8_t val0 = vals.bits.nib1;
|
||||
int8_t val1 = vals.bits.nib2;
|
||||
|
||||
val0 = (val0 >= 8 ? val0 - 16 : val0);
|
||||
val1 = (val1 >= 8 ? val1 - 16 : val1);
|
||||
|
||||
fvals[0] = (float) val0;
|
||||
fvals[1] = (float) val1;
|
||||
}
|
||||
|
||||
// condense two float vals to a nibbles struct
|
||||
nibbles cast_f32_to_i4(float val0, float val1)
|
||||
{
|
||||
nibbles vals;
|
||||
|
||||
int8_t val0_ = ((int8_t)val0) & 0xf0;
|
||||
int8_t val1_ = ((int8_t)val1) & 0xf0;
|
||||
|
||||
vals.bits.nib1 = val0_;
|
||||
vals.bits.nib2 = val1_;
|
||||
|
||||
return vals;
|
||||
}
|
||||
|
||||
// cast float matrix to float nibbles
|
||||
void cast_f32_to_i4m(float *a_float, nibbles *a, int num_elems)
|
||||
{
|
||||
int j=0;
|
||||
for(int i=0; i<num_elems; i+=2)
|
||||
{
|
||||
float val1 = a_float[i];
|
||||
float val0 = a_float[i+1];
|
||||
|
||||
a[j] = cast_f32_to_i4(val0, val1);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
|
||||
// cast nibbles matrix to float matrix
|
||||
void cast_i4_to_f32m(nibbles *a, float *a_float, int num_elems)
|
||||
{
|
||||
int j=0;
|
||||
float *fvals = (float *)malloc(2*sizeof(float));
|
||||
for(int i=0; i<num_elems; i+=2)
|
||||
{
|
||||
nibbles vals = a[j];
|
||||
j++;
|
||||
cast_i4_to_f32(fvals, vals);
|
||||
a_float[i] = fvals[0];
|
||||
a_float[i+1] = fvals[1];
|
||||
}
|
||||
free(fvals);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// cast single element using C casting
|
||||
|
||||
EASY_CAST_FUNC(f32, f32, float, float);
|
||||
EASY_CAST_FUNC(f32, i32, float, int);
|
||||
EASY_CAST_FUNC(f32, i16, float, int16_t);
|
||||
EASY_CAST_FUNC(f32, i8, float, int8_t);
|
||||
|
||||
EASY_CAST_FUNC(i32, f32, int, float);
|
||||
EASY_CAST_FUNC(i16, f32, int16_t, float);
|
||||
EASY_CAST_FUNC( i8, f32, int8_t, float);
|
||||
|
||||
|
||||
// cast entire matrix buffer
|
||||
|
||||
CASTING_MATRIX_FUNC(f32, f32, float, float, cast_f32_to_f32);
|
||||
CASTING_MATRIX_FUNC(f32, bf16, float, bfloat16, cast_f32_to_bf16);
|
||||
CASTING_MATRIX_FUNC(f32, f16, float, float16, cast_f32_to_f16);
|
||||
CASTING_MATRIX_FUNC(f32, i32, float, int, cast_f32_to_i32);
|
||||
CASTING_MATRIX_FUNC(f32, i16, float, int16_t, cast_f32_to_i16);
|
||||
CASTING_MATRIX_FUNC(f32, i8, float, int8_t, cast_f32_to_i8);
|
||||
|
||||
CASTING_MATRIX_FUNC(bf16, f32, bfloat16, float, cast_bf16_to_f32);
|
||||
CASTING_MATRIX_FUNC( f16, f32, float16, float, cast_f16_to_f32);
|
||||
CASTING_MATRIX_FUNC( i32, f32, int, float, cast_i32_to_f32);
|
||||
CASTING_MATRIX_FUNC( i16, f32, int16_t, float, cast_i16_to_f32);
|
||||
CASTING_MATRIX_FUNC( i8, f32, int8_t, float, cast_i8_to_f32);
|
||||
|
||||
62
sandbox/power10/p10_testsuite/cast_funcs.h
Normal file
62
sandbox/power10/p10_testsuite/cast_funcs.h
Normal file
@@ -0,0 +1,62 @@
|
||||
#include "blis.h"
|
||||
|
||||
#define EASY_CAST_FUNC_NAME_(ch_src, ch_dst) cast_ ## ch_src ## _to_ ## ch_dst
|
||||
#define EASY_CAST_FUNC_NAME(ch_src, ch_dst) EASY_CAST_FUNC_NAME_(ch_src, ch_dst)
|
||||
|
||||
#define CAST_MATRIX_FUNC_NAME_(ch_src, ch_dst) cast_ ## ch_src ## _to_ ## ch_dst ## m
|
||||
#define CAST_MATRIX_FUNC_NAME(ch_src, ch_dst) CAST_MATRIX_FUNC_NAME_(ch_src, ch_dst)
|
||||
|
||||
#define CAST_MATRIX_FUNC_PROTO(ch_src, ch_dst, src_t, dst_t) \
|
||||
void CAST_MATRIX_FUNC_NAME(ch_src, ch_dst) (src_t *, dst_t *, int)
|
||||
|
||||
#define EASY_CAST_FUNC_PROTO(ch_src, ch_dst, src_t, dst_t) \
|
||||
dst_t EASY_CAST_FUNC_NAME(ch_src, ch_dst) (src_t)
|
||||
|
||||
#define EASY_CAST_FUNC(ch_src, ch_dst, src_t, dst_t) \
|
||||
dst_t EASY_CAST_FUNC_NAME(ch_src, ch_dst) \
|
||||
(src_t val) { \
|
||||
return (dst_t) val; \
|
||||
}
|
||||
|
||||
#define CASTING_MATRIX_FUNC(ch_src, ch_dst, src_t, dst_t, cast_func) \
|
||||
void CAST_MATRIX_FUNC_NAME(ch_src, ch_dst) \
|
||||
(src_t *m1, dst_t *m2, int num_elems) { \
|
||||
for(int i=0;i<num_elems;i++) \
|
||||
m2[i] = cast_func (m1[i]); \
|
||||
}
|
||||
|
||||
float cast_bf16_to_f32(bfloat16 val);
|
||||
float cast_f16_to_f32(float16 val);
|
||||
|
||||
float16 cast_f32_to_f16(const float in);
|
||||
bfloat16 cast_f32_to_bf16 (float val);
|
||||
|
||||
void cast_i4_to_f32(float *fvals, nibbles val);
|
||||
nibbles cast_f32_to_i4(float val0, float val1);
|
||||
|
||||
void cast_f32_to_i4m(float *a_float, nibbles *a, int num_elems);
|
||||
void cast_i4_to_f32m(nibbles *a, float *a_float, int num_elems);
|
||||
|
||||
EASY_CAST_FUNC_PROTO(f32, f32, float, float);
|
||||
|
||||
EASY_CAST_FUNC_PROTO(f32, i32, float, int32_t);
|
||||
EASY_CAST_FUNC_PROTO(f32, i16, float, int16_t);
|
||||
EASY_CAST_FUNC_PROTO(f32, i8, float, int8_t);
|
||||
|
||||
EASY_CAST_FUNC_PROTO(i32, f32, int32_t, float);
|
||||
EASY_CAST_FUNC_PROTO(i16, f32, int16_t, float);
|
||||
EASY_CAST_FUNC_PROTO( i8, f32, int8_t, float);
|
||||
|
||||
CAST_MATRIX_FUNC_PROTO(f32, f32, float, float);
|
||||
CAST_MATRIX_FUNC_PROTO(f32, bf16, float, bfloat16);
|
||||
CAST_MATRIX_FUNC_PROTO(f32, f16, float, float16);
|
||||
|
||||
CAST_MATRIX_FUNC_PROTO(f32, i32, float, int32_t);
|
||||
CAST_MATRIX_FUNC_PROTO(f32, i16, float, int16_t);
|
||||
CAST_MATRIX_FUNC_PROTO(f32, i8, float, int8_t);
|
||||
|
||||
CAST_MATRIX_FUNC_PROTO(bf16, f32, bfloat16, float);
|
||||
CAST_MATRIX_FUNC_PROTO( f16, f32, float16, float);
|
||||
CAST_MATRIX_FUNC_PROTO( i32, f32, int32_t, float);
|
||||
CAST_MATRIX_FUNC_PROTO( i16, f32, int16_t, float);
|
||||
CAST_MATRIX_FUNC_PROTO( i8, f32, int8_t, float);
|
||||
16
sandbox/power10/p10_testsuite/common.h
Normal file
16
sandbox/power10/p10_testsuite/common.h
Normal file
@@ -0,0 +1,16 @@
|
||||
|
||||
#ifndef COMMON_H
|
||||
#define COMMON_H
|
||||
|
||||
// enumerate all datatypes that will be tested
|
||||
enum DATATYPES {
|
||||
DOUBLE ,
|
||||
SINGLE ,
|
||||
FLOAT16 ,
|
||||
BFLOAT16,
|
||||
INT16 ,
|
||||
INT8 ,
|
||||
INT4
|
||||
};
|
||||
|
||||
#endif
|
||||
337
sandbox/power10/p10_testsuite/correctness.c
Normal file
337
sandbox/power10/p10_testsuite/correctness.c
Normal file
@@ -0,0 +1,337 @@
|
||||
/*
|
||||
|
||||
This program is designed to test the correctness of the POWER10 GEMM
|
||||
kernels in `blis/sandbox/power10`.
|
||||
|
||||
By default, the correctness of the kernels is determined by measuring how
|
||||
close the return value of the following function is to zero for square
|
||||
matrix sizes.
|
||||
|
||||
F(A, B, C_orig, C_ans, alpha, beta, t) =
|
||||
|
||||
normf( (C_ans * t) - ((alpha * A * B + beta * C_orig) * t) )
|
||||
|
||||
The function above can only be used to measure correctness if
|
||||
A, B, C_orig, and t have been randomized and normalized.
|
||||
|
||||
The correctness is reported by printing the function return value along
|
||||
with the matrices' sizes.
|
||||
|
||||
*/
|
||||
|
||||
|
||||
#include "blis.h"
|
||||
#include "cast_funcs.h"
|
||||
#include "correctness.h"
|
||||
#include "../bli_sandbox.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <stdio.h>
|
||||
// print kernel name
|
||||
const char* get_kernel_name(int kernel_id)
|
||||
{
|
||||
switch (kernel_id)
|
||||
{
|
||||
case FLOAT16 : return "bli_shgemm";
|
||||
case BFLOAT16: return "bli_sbgemm";
|
||||
case INT16 : return "bli_i16gemm";
|
||||
case INT8 : return "bli_i8gemm";
|
||||
case INT4 : return "bli_i4gemm";
|
||||
default: printf("INCORRECT KERNEL ID\n"); exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// normalize the vector using the forbenious norm
|
||||
void normalize_vec(float *t, int n)
|
||||
{
|
||||
// normalize t
|
||||
float norm_factor;
|
||||
bli_snormfv(n, t, 1, &norm_factor);
|
||||
// round up to closest power of 2
|
||||
norm_factor = 1 / (pow( 2.0, ceil( log2( norm_factor ) ) ));
|
||||
bli_sscalv(BLIS_NO_CONJUGATE, n, &norm_factor, t, 1);
|
||||
}
|
||||
|
||||
// Pre-conditions:
|
||||
// - a is randomized.
|
||||
// - b is randomized.
|
||||
// - c_orig is randomized.
|
||||
// Note:
|
||||
// - alpha and beta should have non-zero imaginary components in the
|
||||
// complex cases in order to more fully exercise the implementation.
|
||||
//
|
||||
// Under these conditions, we assume that the implementation for
|
||||
//
|
||||
// C := beta * C_orig + alpha * transa(A) * transb(B)
|
||||
//
|
||||
// is functioning correctly if
|
||||
//
|
||||
// normfv( v - z )
|
||||
//
|
||||
// is negligible, where
|
||||
//
|
||||
// v = C * t
|
||||
// z = ( beta * C_orig + alpha * transa(A) * transb(B) ) * t
|
||||
// = beta * C_orig * t + alpha * transa(A) * transb(B) * t
|
||||
// = beta * C_orig * t + alpha * transa(A) * w
|
||||
// = beta * C_orig * t + z
|
||||
float get_resid(
|
||||
int m, int n, int k,
|
||||
float *a, int rsa, int csa,
|
||||
float *b, int rsb, int csb,
|
||||
float *c, int rsc, int csc,
|
||||
float *c_orig,
|
||||
float *alpha, float *beta
|
||||
)
|
||||
{
|
||||
|
||||
float t[n], v[m], w[k], z[m];
|
||||
float one = 1.0, zero = 0.0;
|
||||
|
||||
bli_srandv(n, t, 1);
|
||||
|
||||
// normalize so that the values are at the same precision of the input values
|
||||
normalize_vec(t, n);
|
||||
|
||||
// v = C * t
|
||||
bli_sgemv(
|
||||
BLIS_NO_TRANSPOSE,
|
||||
BLIS_NO_CONJUGATE,
|
||||
m,
|
||||
n,
|
||||
&one,
|
||||
c, rsc, csc,
|
||||
t, 1,
|
||||
&zero,
|
||||
v, 1
|
||||
);
|
||||
|
||||
// w = B * t
|
||||
bli_sgemv(
|
||||
BLIS_NO_TRANSPOSE,
|
||||
BLIS_NO_CONJUGATE,
|
||||
k,
|
||||
n,
|
||||
&one,
|
||||
b, rsb, csb,
|
||||
t, 1,
|
||||
&zero,
|
||||
w, 1
|
||||
);
|
||||
|
||||
// z = alpha * A * w
|
||||
bli_sgemv(
|
||||
BLIS_NO_TRANSPOSE,
|
||||
BLIS_NO_CONJUGATE,
|
||||
m,
|
||||
k,
|
||||
alpha,
|
||||
a, rsa, csa,
|
||||
w, 1,
|
||||
&zero,
|
||||
z, 1
|
||||
);
|
||||
|
||||
// z += beta * C_orig * t
|
||||
bli_sgemv(
|
||||
BLIS_NO_TRANSPOSE,
|
||||
BLIS_NO_CONJUGATE,
|
||||
m,
|
||||
n,
|
||||
beta,
|
||||
c_orig, rsc, csc,
|
||||
t, 1,
|
||||
&one,
|
||||
z, 1
|
||||
);
|
||||
|
||||
// v = v - z
|
||||
bli_ssubv (
|
||||
BLIS_NO_CONJUGATE,
|
||||
m,
|
||||
z, 1,
|
||||
v, 1
|
||||
);
|
||||
|
||||
// norm = normfv(v)
|
||||
float norm;
|
||||
bli_snormfv (
|
||||
m,
|
||||
v, 1,
|
||||
&norm
|
||||
);
|
||||
|
||||
return norm;
|
||||
}
|
||||
|
||||
|
||||
// test to see if the result from a BLIS GEMM kernel is correct for a given m x n x k mat-mul
|
||||
// assumes the matrices are of type float
|
||||
// assumes the matrices were randomized and normalized
|
||||
void correctness_checker(
|
||||
int m, int n, int k,
|
||||
float *a, int rsa, int csa,
|
||||
float *b, int rsb, int csb,
|
||||
float *c_orig, int rsc, int csc,
|
||||
float *c_ans,
|
||||
float alpha, float beta
|
||||
)
|
||||
{
|
||||
double start, end;
|
||||
|
||||
start = bli_clock();
|
||||
float resid = get_resid (
|
||||
m, n, k,
|
||||
a, rsa, csa,
|
||||
b, rsb, csb,
|
||||
c_ans, rsc, csc,
|
||||
c_orig,
|
||||
&alpha, &beta
|
||||
);
|
||||
end = bli_clock();
|
||||
|
||||
printf("%d, %d, %d, %8.4le\n", m,n,k, resid);
|
||||
}
|
||||
|
||||
|
||||
// create all the correctness checking functions for each kernel
|
||||
GEN_FP_COR_KERNEL(sb, bli_sbgemm, bfloat16, cast_f32_to_bf16m, cast_bf16_to_f32m);
|
||||
GEN_FP_COR_KERNEL(sh, bli_shgemm, float16, cast_f32_to_f16m, cast_f16_to_f32m);
|
||||
GEN_I_COR_KERNEL(i16, bli_i16gemm, int16_t, cast_f32_to_i16m, cast_i16_to_f32m);
|
||||
GEN_I_COR_KERNEL(i8, bli_i8gemm, int8_t, cast_f32_to_i8m, cast_i8_to_f32m);
|
||||
|
||||
// correctness template for int types
|
||||
void i4correctness_kernel (int m, int n, int k)
|
||||
{
|
||||
if(n%2 != 0)
|
||||
{
|
||||
printf("int4 can't handle odd sizes in the data-order dimension");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int rsa = k, csa = 1,
|
||||
rsb = n, csb = 1,
|
||||
rsc = n, csc = 1;
|
||||
|
||||
nibbles *a, *b;
|
||||
|
||||
int32_t *c_ans, *c_orig, alpha, beta;
|
||||
|
||||
float *a_float, *b_float,
|
||||
*c_ans_float, *c_orig_float;
|
||||
|
||||
/* buffers that will be passed into the kernel */
|
||||
// int4 buffers only need half the space to store all the elements
|
||||
a = (nibbles *) malloc (m * (k/2) * sizeof(nibbles));
|
||||
b = (nibbles *) malloc (k * (n/2) * sizeof(nibbles));
|
||||
|
||||
c_ans = (int32_t *) malloc (m * n * sizeof(int32_t));
|
||||
c_orig = (int32_t *) malloc (m * n * sizeof(int32_t));
|
||||
|
||||
/* std format buffers that will be used by the correctness checker */
|
||||
a_float = (float *) malloc (m * k * sizeof(float));
|
||||
b_float = (float *) malloc (k * n * sizeof(float));
|
||||
c_ans_float = (float *) malloc (m * n * sizeof(float));
|
||||
c_orig_float = (float *) malloc (m * n * sizeof(float));
|
||||
|
||||
/* randomize matrices with float vals */
|
||||
bli_srandv(m*k, a_float, 1);
|
||||
bli_srandv(k*n, b_float, 1);
|
||||
bli_srandv(m*n, c_orig_float, 1);
|
||||
|
||||
/* normalize the matrices */
|
||||
normalize_vec(a_float, m*k);
|
||||
normalize_vec(b_float, k*n);
|
||||
normalize_vec(c_orig_float, m*n);
|
||||
|
||||
/* cast the float buffers into the buffers for the kernel */
|
||||
cast_f32_to_i4m (a_float, a, m*k);
|
||||
cast_f32_to_i4m (b_float, b, k*n);
|
||||
|
||||
/* cast float buffers to support int values */
|
||||
cast_f32_to_i32m(c_orig_float, c_orig, m*n);
|
||||
cast_i32_to_f32m(c_orig, c_orig_float, m*n);
|
||||
|
||||
/* cast the kernel buffers into the float buffers to ensure that the values match */
|
||||
cast_i4_to_f32m (a, a_float, m*k);
|
||||
cast_i4_to_f32m (b, b_float, k*n);
|
||||
|
||||
/* init alpha and beta */
|
||||
alpha = 1;
|
||||
beta = 1;
|
||||
|
||||
/* run kernel to get result in c_ans */
|
||||
// strides need to be adjusted since 1 element stores 2 values
|
||||
memcpy(c_ans, c_orig, m * n * sizeof(int));
|
||||
bli_i4gemm(
|
||||
BLIS_NO_TRANSPOSE,
|
||||
BLIS_NO_TRANSPOSE,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
&alpha,
|
||||
a, rsa/2, csa,
|
||||
b, rsb/2, csb,
|
||||
&beta,
|
||||
c_ans, rsc, csc
|
||||
);
|
||||
|
||||
/* cast integer result into float buffer since float is our std format for correctness checking */
|
||||
cast_i32_to_f32m(c_ans, c_ans_float, m*n);
|
||||
|
||||
/* using the BLIS GEMM correctness check method, get the resid */
|
||||
correctness_checker(
|
||||
m, n, k,
|
||||
a_float, rsa, csa,
|
||||
b_float, rsb, csb,
|
||||
c_orig_float, rsc, csc,
|
||||
c_ans_float,
|
||||
(float) alpha, (float) beta
|
||||
);
|
||||
|
||||
free(a);
|
||||
free(b);
|
||||
free(c_ans);
|
||||
free(c_orig);
|
||||
free(a_float);
|
||||
free(b_float);
|
||||
free(c_ans_float);
|
||||
free(c_orig_float);
|
||||
}
|
||||
|
||||
// using the DATATYPE enum, gather test the correctness of the respective GEMM kernel
|
||||
void run_correctness_kernel(int kernel_id, int m, int n, int k)
|
||||
{
|
||||
switch (kernel_id)
|
||||
{
|
||||
case FLOAT16 : shcorrectness_kernel(m, n, k); break;
|
||||
case BFLOAT16: sbcorrectness_kernel(m, n, k); break;
|
||||
case INT16 : i16correctness_kernel(m, n, k); break;
|
||||
case INT8 : i8correctness_kernel(m, n, k); break;
|
||||
case INT4 : i4correctness_kernel(m, n, k); break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
void test_correctness(int kernel_id, int start, int end, int inc)
|
||||
{
|
||||
printf("%s correctness test\n", get_kernel_name(kernel_id));
|
||||
printf("m, n, k, resid\n");
|
||||
int m,n,k;
|
||||
for (int p=start; p<=end; p+=inc)
|
||||
{
|
||||
m=n=k=p;
|
||||
run_correctness_kernel(kernel_id, m, n, k);
|
||||
}
|
||||
}
|
||||
|
||||
// correctness test for bfloat16 gemm
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
|
||||
test_correctness(FLOAT16, 80, 4000, 80);
|
||||
test_correctness(BFLOAT16, 80, 4000, 80);
|
||||
test_correctness(INT16, 80, 4000, 80);
|
||||
test_correctness(INT8, 80, 4000, 80);
|
||||
test_correctness(INT4, 80, 4000, 80);
|
||||
}
|
||||
176
sandbox/power10/p10_testsuite/correctness.h
Normal file
176
sandbox/power10/p10_testsuite/correctness.h
Normal file
@@ -0,0 +1,176 @@
|
||||
// templates for generating correctness checking functions that check the correctness of GEMM kernels
|
||||
// using the BLIS GEMM correctness method
|
||||
|
||||
#define COR_KERNEL_NAME_(ch) ch ## correctness_kernel
|
||||
#define COR_KERNEL_NAME(ch) COR_KERNEL_NAME_(ch)
|
||||
|
||||
|
||||
// correctness template for float types
|
||||
#define GEN_FP_COR_KERNEL(ch, kernel, input_t, DOWN_CAST, UP_CAST) \
|
||||
void COR_KERNEL_NAME(ch) (int m, int n, int k) \
|
||||
{ \
|
||||
int rsa = k, csa = 1, \
|
||||
rsb = n, csb = 1, \
|
||||
rsc = n, csc = 1; \
|
||||
\
|
||||
input_t *a, *b; \
|
||||
\
|
||||
float *a_float, *b_float, \
|
||||
*c_ans_float, *c_orig_float, \
|
||||
alpha, beta; \
|
||||
\
|
||||
/* buffers that will be passed into the kernel */ \
|
||||
a = (input_t *) malloc (m * k * sizeof(input_t)); \
|
||||
b = (input_t *) malloc (k * n * sizeof(input_t)); \
|
||||
\
|
||||
/* std format buffers that will be used by the correctness checker */ \
|
||||
a_float = (float *) malloc (m * k * sizeof(float)); \
|
||||
b_float = (float *) malloc (k * n * sizeof(float)); \
|
||||
c_ans_float = (float *) malloc (m * n * sizeof(float)); \
|
||||
c_orig_float = (float *) malloc (m * n * sizeof(float)); \
|
||||
\
|
||||
/* randomize matrices with float vals */ \
|
||||
bli_srandv(m*k, a_float, 1); \
|
||||
bli_srandv(k*n, b_float, 1); \
|
||||
bli_srandv(m*n, c_orig_float, 1); \
|
||||
\
|
||||
/* normalize the matrices */ \
|
||||
normalize_vec(a_float, m*k); \
|
||||
normalize_vec(b_float, k*n); \
|
||||
normalize_vec(c_orig_float, m*n); \
|
||||
\
|
||||
/* cast the float buffers into the buffers for the kernel */ \
|
||||
DOWN_CAST (a_float, a, m*k); \
|
||||
DOWN_CAST (b_float, b, k*n); \
|
||||
\
|
||||
/* cast the kernel buffers into the float buffers to ensure that the values match */ \
|
||||
UP_CAST (a, a_float, m*k); \
|
||||
UP_CAST (b, b_float, k*n); \
|
||||
\
|
||||
/* init alpha and beta */ \
|
||||
alpha = 1; \
|
||||
beta = 1; \
|
||||
\
|
||||
memcpy(c_ans_float, c_orig_float, m * n * sizeof(float)); \
|
||||
kernel( \
|
||||
BLIS_NO_TRANSPOSE, \
|
||||
BLIS_NO_TRANSPOSE, \
|
||||
m, \
|
||||
n, \
|
||||
k, \
|
||||
&alpha, \
|
||||
a, rsa, csa, \
|
||||
b, rsb, csb, \
|
||||
&beta, \
|
||||
c_ans_float, rsc, csc \
|
||||
); \
|
||||
\
|
||||
correctness_checker( \
|
||||
m, n, k, \
|
||||
a_float, rsa, csa, \
|
||||
b_float, rsb, csb, \
|
||||
c_orig_float, rsc, csc, \
|
||||
c_ans_float, \
|
||||
alpha, beta \
|
||||
); \
|
||||
\
|
||||
free(a); \
|
||||
free(b); \
|
||||
free(a_float); \
|
||||
free(b_float); \
|
||||
free(c_ans_float); \
|
||||
free(c_orig_float); \
|
||||
\
|
||||
}
|
||||
|
||||
// correctness template for int types
|
||||
#define GEN_I_COR_KERNEL(ch, kernel, input_t, DOWN_CAST, UP_CAST) \
|
||||
void COR_KERNEL_NAME(ch) (int m, int n, int k) \
|
||||
{ \
|
||||
int rsa = k, csa = 1, \
|
||||
rsb = n, csb = 1, \
|
||||
rsc = n, csc = 1; \
|
||||
\
|
||||
input_t *a, *b; \
|
||||
\
|
||||
int32_t *c_ans, *c_orig, alpha, beta; \
|
||||
\
|
||||
float *a_float, *b_float, \
|
||||
*c_ans_float, *c_orig_float; \
|
||||
\
|
||||
/* buffers that will be passed into the kernel */ \
|
||||
a = (input_t *) malloc (m * k * sizeof(input_t)); \
|
||||
b = (input_t *) malloc (k * n * sizeof(input_t)); \
|
||||
c_ans = (int32_t *) malloc (m * n * sizeof(int32_t)); \
|
||||
c_orig = (int32_t *) malloc (m * n * sizeof(int32_t)); \
|
||||
\
|
||||
/* std format buffers that will be used by the correctness checker */ \
|
||||
a_float = (float *) malloc (m * k * sizeof(float)); \
|
||||
b_float = (float *) malloc (k * n * sizeof(float)); \
|
||||
c_ans_float = (float *) malloc (m * n * sizeof(float)); \
|
||||
c_orig_float = (float *) malloc (m * n * sizeof(float)); \
|
||||
\
|
||||
/* randomize matrices with float vals */ \
|
||||
bli_srandv(m*k, a_float, 1); \
|
||||
bli_srandv(k*n, b_float, 1); \
|
||||
bli_srandv(m*n, c_orig_float, 1); \
|
||||
\
|
||||
/* normalize the matrices */ \
|
||||
normalize_vec(a_float, m*k); \
|
||||
normalize_vec(b_float, k*n); \
|
||||
normalize_vec(c_orig_float, m*n); \
|
||||
\
|
||||
/* cast the float buffers into the buffers for the kernel */ \
|
||||
DOWN_CAST (a_float, a, m*k); \
|
||||
DOWN_CAST (b_float, b, k*n); \
|
||||
\
|
||||
/* cast float buffers to support int values */ \
|
||||
cast_f32_to_i32m(c_orig_float, c_orig, m*n); \
|
||||
cast_i32_to_f32m(c_orig, c_orig_float, m*n); \
|
||||
\
|
||||
/* cast the kernel buffers into the float buffers to ensure that the values match */ \
|
||||
UP_CAST (a, a_float, m*k); \
|
||||
UP_CAST (b, b_float, k*n); \
|
||||
\
|
||||
/* init alpha and beta */ \
|
||||
alpha = 1; \
|
||||
beta = 1; \
|
||||
\
|
||||
/* run kernel to get result in c_ans */ \
|
||||
memcpy(c_ans, c_orig, m * n * sizeof(int)); \
|
||||
kernel( \
|
||||
BLIS_NO_TRANSPOSE, \
|
||||
BLIS_NO_TRANSPOSE, \
|
||||
m, \
|
||||
n, \
|
||||
k, \
|
||||
&alpha, \
|
||||
a, rsa, csa, \
|
||||
b, rsb, csb, \
|
||||
&beta, \
|
||||
c_ans, rsc, csc \
|
||||
); \
|
||||
\
|
||||
/* cast integer result into float buffer since float is our std format for correctness checking */ \
|
||||
cast_i32_to_f32m(c_ans, c_ans_float, m*n); \
|
||||
\
|
||||
/* using the BLIS GEMM correctness check method, get the resid */ \
|
||||
correctness_checker( \
|
||||
m, n, k, \
|
||||
a_float, rsa, csa, \
|
||||
b_float, rsb, csb, \
|
||||
c_orig_float, rsc, csc, \
|
||||
c_ans_float, \
|
||||
(float) alpha, (float) beta \
|
||||
); \
|
||||
\
|
||||
free(a); \
|
||||
free(b); \
|
||||
free(c_ans); \
|
||||
free(c_orig); \
|
||||
free(a_float); \
|
||||
free(b_float); \
|
||||
free(c_ans_float); \
|
||||
free(c_orig_float); \
|
||||
\
|
||||
}
|
||||
103
sandbox/power10/p10_testsuite/performance.c
Normal file
103
sandbox/power10/p10_testsuite/performance.c
Normal file
@@ -0,0 +1,103 @@
|
||||
/*
|
||||
|
||||
This program is designed to gather the performance data of the POWER10
|
||||
GEMM kernels in `blis/sandbox/power10`.
|
||||
|
||||
By default, the performance of the kernels is gather over a set of square
|
||||
matrices. The perfromance results are reported in GFLOPS, and outputted in
|
||||
CSV format.
|
||||
|
||||
*/
|
||||
|
||||
#include "performance.h"
|
||||
#include "blis.h"
|
||||
#include "../bli_sandbox.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <stdio.h>
|
||||
// print kernel name
|
||||
const char* get_kernel_name(int kernel_id)
|
||||
{
|
||||
switch (kernel_id)
|
||||
{
|
||||
case FLOAT16 : return "bli_shgemm";
|
||||
case BFLOAT16: return "bli_sbgemm";
|
||||
case INT16 : return "bli_i16gemm";
|
||||
case INT8 : return "bli_i8gemm";
|
||||
case INT4 : return "bli_i4gemm";
|
||||
default: printf("INCORRECT KERNEL ID\n"); exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// create all the performance gathering functions for each kernel
|
||||
GET_PERF_API_TEMP(sb, bli_sbgemm, bfloat16, float);
|
||||
GET_PERF_API_TEMP(sh, bli_shgemm, float16, float);
|
||||
GET_PERF_API_TEMP(i16, bli_i16gemm, int16_t, int);
|
||||
GET_PERF_API_TEMP(i8, bli_i8gemm, int8_t, int);
|
||||
GET_PERF_API_TEMP(i4, bli_i4gemm, nibbles, int);
|
||||
|
||||
|
||||
// using the DATATYPE enum, gather the performance of the respective GEMM kernel
|
||||
double run_kernel(int kernel_id, int nreps, int m, int n, int k)
|
||||
{
|
||||
switch (kernel_id)
|
||||
{
|
||||
case FLOAT16 : return test_shapi(nreps, m, n, k);
|
||||
case BFLOAT16: return test_sbapi(nreps, m, n, k);
|
||||
case INT16 : return test_i16api(nreps, m, n, k);
|
||||
case INT8 : return test_i8api(nreps, m, n, k);
|
||||
case INT4 : return test_i4api(nreps, m, n, k);
|
||||
default: return -1.0;
|
||||
}
|
||||
}
|
||||
|
||||
// print the performance data in CSV format
|
||||
// performance is measured in terms of GFLOPs
|
||||
void print_perf_data(int m, int n, int k, double best_time)
|
||||
{
|
||||
double GFLOPS = (2.0 * m * n * k) / (1e9 * best_time);
|
||||
printf("%d, %d, %d, %.2f\n", m, n, k, GFLOPS);
|
||||
}
|
||||
|
||||
// get performance data
|
||||
void get_perf(int kernel_id, int nreps, int start, int end, int inc)
|
||||
{
|
||||
// csv header
|
||||
printf("%s performance\n", get_kernel_name(kernel_id));
|
||||
printf("m, n, k, GFLOPS\n");
|
||||
|
||||
int m,n,k;
|
||||
|
||||
// run over all problem sizes
|
||||
for (int p=start; p<=end; p+=inc)
|
||||
{
|
||||
// change here to adjust problem size
|
||||
m = p,
|
||||
n = p,
|
||||
k = p;
|
||||
|
||||
double best_run_time = run_kernel(kernel_id, nreps, m, n, k);
|
||||
|
||||
print_perf_data(m, n, k, best_run_time);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
// initialize a square problem set range
|
||||
int start = 80;
|
||||
int end = 4000;
|
||||
int inc = 80;
|
||||
|
||||
// number of times the kernel will be run
|
||||
int nreps = 5;
|
||||
|
||||
// run a respective kernel
|
||||
get_perf( FLOAT16, nreps, start, end, inc);
|
||||
get_perf(BFLOAT16, nreps, start, end, inc);
|
||||
get_perf( INT16, nreps, start, end, inc);
|
||||
get_perf( INT8, nreps, start, end, inc);
|
||||
get_perf( INT4, nreps, start, end, inc);
|
||||
|
||||
return 0;
|
||||
}
|
||||
58
sandbox/power10/p10_testsuite/performance.h
Normal file
58
sandbox/power10/p10_testsuite/performance.h
Normal file
@@ -0,0 +1,58 @@
|
||||
|
||||
// function name template
|
||||
// each function that will gather perform will be named test_<ch>api
|
||||
#define GEN_PERF_FUNC_NAME_(ch) test_ ## ch ## api
|
||||
#define GEN_PERF_FUNC_NAME(ch) GEN_PERF_FUNC_NAME_(ch)
|
||||
|
||||
/*
|
||||
Macro template for getting the best GEMM kernel runtime out of `num_runs`
|
||||
for matrices of size (m x n x k).
|
||||
*/
|
||||
#define GET_PERF_API_TEMP(ch, kernel, input_t, output_t) \
|
||||
double GEN_PERF_FUNC_NAME(ch) ( \
|
||||
int num_runs, \
|
||||
int m, \
|
||||
int n, \
|
||||
int k \
|
||||
) \
|
||||
{ \
|
||||
input_t *A,*B; \
|
||||
output_t *C; \
|
||||
output_t alpha,beta; \
|
||||
\
|
||||
A = (input_t*) malloc(m*k*sizeof(input_t)); \
|
||||
B = (input_t*) malloc(n*k*sizeof(input_t)); \
|
||||
C = (output_t*) malloc(m*n*sizeof(output_t)); \
|
||||
\
|
||||
alpha = 1; \
|
||||
beta = 1; \
|
||||
\
|
||||
double best = 1e9; \
|
||||
\
|
||||
for (int irep=0; irep<num_runs; irep++) \
|
||||
{ \
|
||||
double start = bli_clock(); \
|
||||
kernel( \
|
||||
BLIS_NO_TRANSPOSE, \
|
||||
BLIS_NO_TRANSPOSE, \
|
||||
m, \
|
||||
n, \
|
||||
k, \
|
||||
&alpha, \
|
||||
A, k, 1, \
|
||||
B, n, 1, \
|
||||
&beta, \
|
||||
C, n, 1 \
|
||||
); \
|
||||
double end = bli_clock(); \
|
||||
\
|
||||
best = bli_min(best, end-start); \
|
||||
} \
|
||||
\
|
||||
free(A); \
|
||||
free(B); \
|
||||
free(C); \
|
||||
\
|
||||
return best; \
|
||||
} \
|
||||
|
||||
426
sandbox/power10/pack_a_templates.h
Normal file
426
sandbox/power10/pack_a_templates.h
Normal file
@@ -0,0 +1,426 @@
|
||||
|
||||
|
||||
|
||||
#define k_even_apack_16(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + p_idx*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (p_idx+1)*cs_a ];
|
||||
|
||||
#define k_odd_apack_16(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-1)*cs_a ]; \
|
||||
memset(adest, 0, 2); \
|
||||
adest++;
|
||||
|
||||
#define pad_macro_16(dest_matrix) \
|
||||
memset(dest_matrix, 0, 4); \
|
||||
dest_matrix+=2;
|
||||
|
||||
#define BIT16_PACK_A(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, a) \
|
||||
( \
|
||||
dim_t MR, \
|
||||
int m, int k, \
|
||||
DTYPE_IN* ap, int rs_a, int cs_a, \
|
||||
DTYPE_IN* apack \
|
||||
) \
|
||||
{ \
|
||||
int k_odd = k%2; \
|
||||
int p_idx; \
|
||||
\
|
||||
DTYPE_IN* adest = apack; \
|
||||
for (int i=0; i<m; i+=MR) \
|
||||
{ \
|
||||
int ib = bli_min(MR, m-i); \
|
||||
if (ib == MR) /* Full size column height */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for (int p=0; p<(k/2); p++) \
|
||||
{ \
|
||||
k_even_apack_16(0); \
|
||||
k_even_apack_16(1); \
|
||||
k_even_apack_16(2); \
|
||||
k_even_apack_16(3); \
|
||||
k_even_apack_16(4); \
|
||||
k_even_apack_16(5); \
|
||||
k_even_apack_16(6); \
|
||||
k_even_apack_16(7); \
|
||||
p_idx += 2; \
|
||||
} \
|
||||
\
|
||||
/* In the case that k is odd, we must pad with 0s */ \
|
||||
if(k_odd) \
|
||||
{ \
|
||||
k_odd_apack_16(0); \
|
||||
k_odd_apack_16(1); \
|
||||
k_odd_apack_16(2); \
|
||||
k_odd_apack_16(3); \
|
||||
k_odd_apack_16(4); \
|
||||
k_odd_apack_16(5); \
|
||||
k_odd_apack_16(6); \
|
||||
k_odd_apack_16(7); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
else /* Not full size, pad with zeros */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for (int p=0; p<(k/2); p++) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_even_apack_16(ir); \
|
||||
} \
|
||||
for (int ir=ib; ir<MR; ir++) \
|
||||
{ \
|
||||
pad_macro_16(adest); \
|
||||
} \
|
||||
p_idx += 2; \
|
||||
} \
|
||||
\
|
||||
if(k_odd) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_odd_apack_16(ir); \
|
||||
} \
|
||||
for (int ir=ib; ir<MR; ir++) \
|
||||
{ \
|
||||
pad_macro_16(adest); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
/* 8 bit packing routines */
|
||||
|
||||
#define k_even_apack_8(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + p_idx*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (p_idx+1)*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (p_idx+2)*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (p_idx+3)*cs_a ];
|
||||
|
||||
#define k_left3_apack_8(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-3)*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-2)*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-1)*cs_a ]; \
|
||||
memset(adest, 0, 1); \
|
||||
adest++;
|
||||
|
||||
#define k_left2_apack_8(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-2)*cs_a ]; \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-1)*cs_a ]; \
|
||||
memset(adest, 0, 2); \
|
||||
adest += 2;
|
||||
|
||||
#define k_left1_apack_8(ir) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + (k-1)*cs_a ]; \
|
||||
memset(adest, 0, 3); \
|
||||
adest += 3;
|
||||
|
||||
#define pad_macro_8(dest_matrix) \
|
||||
memset(dest_matrix, 0, 4); \
|
||||
dest_matrix += 4;
|
||||
|
||||
|
||||
#define BIT8_PACK_A(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, a) \
|
||||
( \
|
||||
dim_t MR, \
|
||||
int m, int k, \
|
||||
DTYPE_IN* ap, int rs_a, int cs_a, \
|
||||
DTYPE_IN* apack \
|
||||
) \
|
||||
{ \
|
||||
int k_left = k%4; \
|
||||
int k_iter = k/4; \
|
||||
int p_idx; \
|
||||
\
|
||||
DTYPE_IN* adest = apack; \
|
||||
\
|
||||
/* Each panel must be packed in this format */ \
|
||||
for (int i=0; i<m; i+=MR) \
|
||||
{ \
|
||||
int ib = bli_min(MR, m-i); \
|
||||
\
|
||||
if (ib == MR) /* Full size column height */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for (int p=0; p<k_iter; p++) \
|
||||
{ \
|
||||
k_even_apack_8(0); \
|
||||
k_even_apack_8(1); \
|
||||
k_even_apack_8(2); \
|
||||
k_even_apack_8(3); \
|
||||
k_even_apack_8(4); \
|
||||
k_even_apack_8(5); \
|
||||
k_even_apack_8(6); \
|
||||
k_even_apack_8(7); \
|
||||
p_idx += 4; \
|
||||
} \
|
||||
\
|
||||
/* In the case that k is odd, we must pad with 0s */ \
|
||||
if(k_left==3) \
|
||||
{ \
|
||||
k_left3_apack_8(0); \
|
||||
k_left3_apack_8(1); \
|
||||
k_left3_apack_8(2); \
|
||||
k_left3_apack_8(3); \
|
||||
k_left3_apack_8(4); \
|
||||
k_left3_apack_8(5); \
|
||||
k_left3_apack_8(6); \
|
||||
k_left3_apack_8(7); \
|
||||
} \
|
||||
else if(k_left==2) \
|
||||
{ \
|
||||
k_left2_apack_8(0); \
|
||||
k_left2_apack_8(1); \
|
||||
k_left2_apack_8(2); \
|
||||
k_left2_apack_8(3); \
|
||||
k_left2_apack_8(4); \
|
||||
k_left2_apack_8(5); \
|
||||
k_left2_apack_8(6); \
|
||||
k_left2_apack_8(7); \
|
||||
} \
|
||||
else if(k_left==1) \
|
||||
{ \
|
||||
k_left1_apack_8(0); \
|
||||
k_left1_apack_8(1); \
|
||||
k_left1_apack_8(2); \
|
||||
k_left1_apack_8(3); \
|
||||
k_left1_apack_8(4); \
|
||||
k_left1_apack_8(5); \
|
||||
k_left1_apack_8(6); \
|
||||
k_left1_apack_8(7); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
else /* Not full size, pad with zeros */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for (int p=0; p<k_iter; p++) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_even_apack_8(ir); \
|
||||
} \
|
||||
for (int ir=ib; ir<MR; ir++) \
|
||||
{ \
|
||||
pad_macro_8(adest); \
|
||||
} \
|
||||
p_idx += 4; \
|
||||
} \
|
||||
\
|
||||
if(k_left==3) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_left3_apack_8(ir); \
|
||||
} \
|
||||
} \
|
||||
else if(k_left==2) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_left2_apack_8(ir); \
|
||||
} \
|
||||
} \
|
||||
else if(k_left==1) \
|
||||
{ \
|
||||
for (int ir=0; ir<ib; ir++) \
|
||||
{ \
|
||||
k_left1_apack_8(ir); \
|
||||
} \
|
||||
} \
|
||||
if(k_left!=0) \
|
||||
{ \
|
||||
for (int ir=ib; ir<MR; ir++) { \
|
||||
pad_macro_8(adest); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/* Packing Routines */
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
|
||||
Memory is byte-addressed. This results in two options when dealing with
|
||||
int4. Either store 1 int4 value in a byte, or store 2 int4 values in 1
|
||||
byte. The former is wasteful in storage, but it makes for a simpler
|
||||
packing routine. However, we want to not waste any storage if possible.
|
||||
Therefore I went with the latter when designing my int4 kernel.
|
||||
|
||||
The int4 outerproduct instruction expects a 4x8 matrix in row major order
|
||||
to be loaded into the vector. In order to achieve this 4x8 row major
|
||||
matrix, we pack as many 4x8 panels from the src matrix into the pack matrix.
|
||||
|
||||
To illustrate how my packing routine works:
|
||||
|
||||
x0 x1 x2 x3 x4 x5 x6 x7
|
||||
x9 x10 x11 x12 x13 x14 x15 x16
|
||||
x17 x18 x19 x20 x21 x22 x23 x24
|
||||
x25 x26 x27 x28 x29 x30 x31 x32
|
||||
|
||||
Assume we have a 4x8 matrix that is stored in column major order. Also
|
||||
since we are dealing with int4 values, the values are stored as pairs
|
||||
within a union struct. i.e. (x0, x9) are stored together in the same struct.
|
||||
|
||||
Therefore in order to get the desired 4x8 row major matrix, we must go
|
||||
through the first row of structs and grab the first int4 value and insert
|
||||
it into the appropriate spot in the pack matrix. This means that after
|
||||
packing, (x0, x1) will be stored together in the same struct.
|
||||
|
||||
This process then repeats until the entire src matrix is packed in these
|
||||
4x8 row major matrix panels.
|
||||
|
||||
To handle edge cases, the packing routine will fill in zeros where it is
|
||||
appropriate.
|
||||
|
||||
*/
|
||||
|
||||
#include "i4_macros.h"
|
||||
|
||||
#define PACK_A(ch) \
|
||||
void PACK_FUNC_NAME(ch, a) \
|
||||
( \
|
||||
dim_t MR, \
|
||||
int m, int k, \
|
||||
uint32_t* ap, int rs_a, int cs_a, \
|
||||
uint32_t* apack \
|
||||
) \
|
||||
{ \
|
||||
uint32_t* restrict adest = apack; \
|
||||
for( int i=0; i<m; i += MR ) \
|
||||
{ \
|
||||
int ib = min(MR, m-i); \
|
||||
if ( ib == MR ) { \
|
||||
for ( int p=0; p<k; p++ ) \
|
||||
for ( int ir=0; ir<MR; ir++ ) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + p*cs_a ]; \
|
||||
} \
|
||||
else { \
|
||||
for ( int p=0; p<k; p++ ) { \
|
||||
for ( int ir=0; ir<ib; ir++ ) \
|
||||
*adest++ = ap[ (i+ir)*rs_a + p*cs_a ]; \
|
||||
for ( int ir=ib; ir<MR; ir++ ) \
|
||||
*adest++ = 0.0; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
#define BIT4_PACK_A(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, a) \
|
||||
( \
|
||||
dim_t MR, \
|
||||
int m, int k, \
|
||||
DTYPE_IN* ap, int rs_a, int cs_a, \
|
||||
DTYPE_IN* apack \
|
||||
) \
|
||||
{ \
|
||||
int p_idx, k_left, k_iter; \
|
||||
DTYPE_IN* adest = apack; \
|
||||
\
|
||||
k_left = k%8; \
|
||||
k_iter = k/8; \
|
||||
\
|
||||
int i = 0; /* i is used for byte addressing */ \
|
||||
for(int int4_i=0; int4_i<m; int4_i+=MR) { /* pack panels */ \
|
||||
\
|
||||
int ib = bli_min(MR, m-int4_i); \
|
||||
p_idx = 0; \
|
||||
\
|
||||
if (ib == MR) { /* full size */ \
|
||||
for (int p=0; p<k_iter; p++) { \
|
||||
col_m_order_1(adest, ap, (i+0), rs_a, cs_a); \
|
||||
col_m_order_2(adest, ap, (i+0), rs_a, cs_a); \
|
||||
col_m_order_1(adest, ap, (i+1), rs_a, cs_a); \
|
||||
col_m_order_2(adest, ap, (i+1), rs_a, cs_a); \
|
||||
col_m_order_1(adest, ap, (i+2), rs_a, cs_a); \
|
||||
col_m_order_2(adest, ap, (i+2), rs_a, cs_a); \
|
||||
col_m_order_1(adest, ap, (i+3), rs_a, cs_a); \
|
||||
col_m_order_2(adest, ap, (i+3), rs_a, cs_a); \
|
||||
p_idx += 8; \
|
||||
} \
|
||||
\
|
||||
/* handle edge cases if there are any */ \
|
||||
if(k_left == 7) { \
|
||||
apad_col_kleft7(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 6) { \
|
||||
apad_col_kleft6(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 5) { \
|
||||
apad_col_kleft5(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 4) { \
|
||||
apad_col_kleft4(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 3) { \
|
||||
apad_col_kleft3(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 2) { \
|
||||
apad_col_kleft2(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 1) { \
|
||||
apad_col_kleft1(adest, ap, rs_a, cs_a); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
else { /* not full size */ \
|
||||
for (int p=0; p<k_iter; p++) { \
|
||||
for (int ir=0; ir<ib; ir++) { \
|
||||
if (ir%2==0) { \
|
||||
col_m_order_1(adest, ap, (i+ir/2), rs_a, cs_a); \
|
||||
} \
|
||||
else { \
|
||||
col_m_order_2(adest, ap, (i+ir/2), rs_a, cs_a); \
|
||||
} \
|
||||
} \
|
||||
for (int ir=ib; ir<MR; ir++) { \
|
||||
zero_out_dest(adest); \
|
||||
} \
|
||||
p_idx += 8; \
|
||||
} \
|
||||
\
|
||||
/* handle edge cases if there are any */ \
|
||||
if(k_left == 7) { \
|
||||
edge7(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 6) { \
|
||||
edge6(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 5) { \
|
||||
edge5(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 4) { \
|
||||
edge4(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 3) { \
|
||||
edge3(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 2) { \
|
||||
edge2(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
else if(k_left == 1) { \
|
||||
edge1(adest, ap, i, ib, rs_a, cs_a); \
|
||||
} \
|
||||
\
|
||||
/* fill in zeros when an edge case occurs */ \
|
||||
if(k_left!=0) \
|
||||
{ \
|
||||
for (int ir=ib; ir<MR; ir++) \
|
||||
zero_out_dest(adest); \
|
||||
} \
|
||||
} \
|
||||
i += (MR/2); \
|
||||
} \
|
||||
}
|
||||
403
sandbox/power10/pack_b_templates.h
Normal file
403
sandbox/power10/pack_b_templates.h
Normal file
@@ -0,0 +1,403 @@
|
||||
|
||||
|
||||
|
||||
#define k_even_bpack_16(jr) \
|
||||
*bdest++ = bp[ p_idx*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (p_idx+1)*rs_b + (j+jr)*cs_b ]; \
|
||||
|
||||
#define k_odd_bpack_16(jr) \
|
||||
*bdest++ = bp[ (k-1)*rs_b + (j+jr)*cs_b ]; \
|
||||
memset(bdest, 0, 2); \
|
||||
bdest++; \
|
||||
|
||||
#define BIT16_PACK_B(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, b) \
|
||||
( \
|
||||
dim_t NR, \
|
||||
int k, int n, \
|
||||
DTYPE_IN* bp, int rs_b, int cs_b, \
|
||||
DTYPE_IN* bpack \
|
||||
) \
|
||||
{ \
|
||||
\
|
||||
int k_odd = k%2; \
|
||||
int p_idx; \
|
||||
\
|
||||
DTYPE_IN* bdest = bpack; \
|
||||
\
|
||||
for( int j=0; j<n; j += NR ) \
|
||||
{ \
|
||||
int jb = bli_min(NR, n-j); \
|
||||
\
|
||||
if ( jb == NR ) /* Full column width micro-panel.*/ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for ( int p=0; p<(k/2); p++ ) \
|
||||
{ \
|
||||
k_even_bpack_16(0); \
|
||||
k_even_bpack_16(1); \
|
||||
k_even_bpack_16(2); \
|
||||
k_even_bpack_16(3); \
|
||||
k_even_bpack_16(4); \
|
||||
k_even_bpack_16(5); \
|
||||
k_even_bpack_16(6); \
|
||||
k_even_bpack_16(7); \
|
||||
k_even_bpack_16(8); \
|
||||
k_even_bpack_16(9); \
|
||||
k_even_bpack_16(10); \
|
||||
k_even_bpack_16(11); \
|
||||
k_even_bpack_16(12); \
|
||||
k_even_bpack_16(13); \
|
||||
k_even_bpack_16(14); \
|
||||
k_even_bpack_16(15); \
|
||||
p_idx += 2; \
|
||||
} \
|
||||
\
|
||||
/* In the case that k is odd, we must pad with 0s */ \
|
||||
if(k_odd) \
|
||||
{ \
|
||||
k_odd_bpack_16(0); \
|
||||
k_odd_bpack_16(1); \
|
||||
k_odd_bpack_16(2); \
|
||||
k_odd_bpack_16(3); \
|
||||
k_odd_bpack_16(4); \
|
||||
k_odd_bpack_16(5); \
|
||||
k_odd_bpack_16(6); \
|
||||
k_odd_bpack_16(7); \
|
||||
k_odd_bpack_16(8); \
|
||||
k_odd_bpack_16(9); \
|
||||
k_odd_bpack_16(10); \
|
||||
k_odd_bpack_16(11); \
|
||||
k_odd_bpack_16(12); \
|
||||
k_odd_bpack_16(13); \
|
||||
k_odd_bpack_16(14); \
|
||||
k_odd_bpack_16(15); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
else /* Not a full row size micro-panel. We pad with zeroes. */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for ( int p=0; p<(k/2); p++ ) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_even_bpack_16(jr); \
|
||||
} \
|
||||
for ( int jr=jb; jr<NR; jr++ ) \
|
||||
{ \
|
||||
pad_macro_16(bdest); \
|
||||
} \
|
||||
p_idx += 2; \
|
||||
} \
|
||||
\
|
||||
if(k_odd) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_odd_bpack_16(jr); \
|
||||
} \
|
||||
for ( int jr=jb; jr<NR; jr++ ) \
|
||||
{ \
|
||||
pad_macro_16(bdest); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
};
|
||||
|
||||
|
||||
#define k_even_bpack_8(jr) \
|
||||
*bdest++ = bp[ p_idx*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (p_idx+1)*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (p_idx+2)*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (p_idx+3)*rs_b + (j+jr)*cs_b ];
|
||||
|
||||
#define k_left3_bpack_8(jr) \
|
||||
*bdest++ = bp[ (k-3)*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (k-2)*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (k-1)*rs_b + (j+jr)*cs_b ]; \
|
||||
memset(bdest, 0, 1); \
|
||||
bdest++;
|
||||
|
||||
#define k_left2_bpack_8(jr) \
|
||||
*bdest++ = bp[ (k-2)*rs_b + (j+jr)*cs_b ]; \
|
||||
*bdest++ = bp[ (k-1)*rs_b + (j+jr)*cs_b ]; \
|
||||
memset(bdest, 0, 2); \
|
||||
bdest+=2;
|
||||
|
||||
#define k_left1_bpack_8(jr) \
|
||||
*bdest++ = bp[ (k-1)*rs_b + (j+jr)*cs_b ]; \
|
||||
memset(bdest, 0, 3); \
|
||||
bdest+=3;
|
||||
|
||||
|
||||
#define BIT8_PACK_B(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, b) \
|
||||
( \
|
||||
dim_t NR, \
|
||||
int k, int n, \
|
||||
DTYPE_IN* bp, int rs_b, int cs_b, \
|
||||
DTYPE_IN* bpack \
|
||||
) \
|
||||
{ \
|
||||
int k_left = k%4; \
|
||||
int k_iter = k/4; \
|
||||
int p_idx; \
|
||||
\
|
||||
DTYPE_IN* bdest = bpack; \
|
||||
\
|
||||
for( int j=0; j<n; j += NR ) \
|
||||
{ \
|
||||
int jb = bli_min(NR, n-j); \
|
||||
\
|
||||
if ( jb == NR ) /* Full column width micro-panel.*/ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for ( int p=0; p<k_iter; p++ ) \
|
||||
{ \
|
||||
k_even_bpack_8(0); \
|
||||
k_even_bpack_8(1); \
|
||||
k_even_bpack_8(2); \
|
||||
k_even_bpack_8(3); \
|
||||
k_even_bpack_8(4); \
|
||||
k_even_bpack_8(5); \
|
||||
k_even_bpack_8(6); \
|
||||
k_even_bpack_8(7); \
|
||||
k_even_bpack_8(8); \
|
||||
k_even_bpack_8(9); \
|
||||
k_even_bpack_8(10); \
|
||||
k_even_bpack_8(11); \
|
||||
k_even_bpack_8(12); \
|
||||
k_even_bpack_8(13); \
|
||||
k_even_bpack_8(14); \
|
||||
k_even_bpack_8(15); \
|
||||
p_idx += 4; \
|
||||
} \
|
||||
\
|
||||
if(k_left==3) \
|
||||
{ \
|
||||
k_left3_bpack_8(0); \
|
||||
k_left3_bpack_8(1); \
|
||||
k_left3_bpack_8(2); \
|
||||
k_left3_bpack_8(3); \
|
||||
k_left3_bpack_8(4); \
|
||||
k_left3_bpack_8(5); \
|
||||
k_left3_bpack_8(6); \
|
||||
k_left3_bpack_8(7); \
|
||||
k_left3_bpack_8(8); \
|
||||
k_left3_bpack_8(9); \
|
||||
k_left3_bpack_8(10); \
|
||||
k_left3_bpack_8(11); \
|
||||
k_left3_bpack_8(12); \
|
||||
k_left3_bpack_8(13); \
|
||||
k_left3_bpack_8(14); \
|
||||
k_left3_bpack_8(15); \
|
||||
} \
|
||||
else if(k_left==2) \
|
||||
{ \
|
||||
k_left2_bpack_8(0); \
|
||||
k_left2_bpack_8(1); \
|
||||
k_left2_bpack_8(2); \
|
||||
k_left2_bpack_8(3); \
|
||||
k_left2_bpack_8(4); \
|
||||
k_left2_bpack_8(5); \
|
||||
k_left2_bpack_8(6); \
|
||||
k_left2_bpack_8(7); \
|
||||
k_left2_bpack_8(8); \
|
||||
k_left2_bpack_8(9); \
|
||||
k_left2_bpack_8(10); \
|
||||
k_left2_bpack_8(11); \
|
||||
k_left2_bpack_8(12); \
|
||||
k_left2_bpack_8(13); \
|
||||
k_left2_bpack_8(14); \
|
||||
k_left2_bpack_8(15); \
|
||||
} \
|
||||
else if(k_left==1) \
|
||||
{ \
|
||||
k_left1_bpack_8(0); \
|
||||
k_left1_bpack_8(1); \
|
||||
k_left1_bpack_8(2); \
|
||||
k_left1_bpack_8(3); \
|
||||
k_left1_bpack_8(4); \
|
||||
k_left1_bpack_8(5); \
|
||||
k_left1_bpack_8(6); \
|
||||
k_left1_bpack_8(7); \
|
||||
k_left1_bpack_8(8); \
|
||||
k_left1_bpack_8(9); \
|
||||
k_left1_bpack_8(10); \
|
||||
k_left1_bpack_8(11); \
|
||||
k_left1_bpack_8(12); \
|
||||
k_left1_bpack_8(13); \
|
||||
k_left1_bpack_8(14); \
|
||||
k_left1_bpack_8(15); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
else /* Not a full row size micro-panel. We pad with zeroes. */ \
|
||||
{ \
|
||||
p_idx = 0; \
|
||||
for ( int p=0; p<k_iter; p++ ) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_even_bpack_8(jr); \
|
||||
} \
|
||||
for ( int jr=jb; jr<NR; jr++ ) \
|
||||
{ \
|
||||
pad_macro_8(bdest); \
|
||||
} \
|
||||
p_idx += 4; \
|
||||
} \
|
||||
\
|
||||
if(k_left==3) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_left3_bpack_8(jr); \
|
||||
} \
|
||||
} \
|
||||
else if(k_left==2) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_left2_bpack_8(jr); \
|
||||
} \
|
||||
} \
|
||||
else if(k_left==1) \
|
||||
{ \
|
||||
for ( int jr=0; jr<jb; jr++ ) \
|
||||
{ \
|
||||
k_left1_bpack_8(jr); \
|
||||
} \
|
||||
} \
|
||||
if (k_left!=0) \
|
||||
{ \
|
||||
for ( int jr=jb; jr<NR; jr++ ) { \
|
||||
pad_macro_8(bdest); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
#include "i4_macros.h"
|
||||
|
||||
#define BIT4_PACK_B(ch, DTYPE_IN) \
|
||||
\
|
||||
void PACK_FUNC_NAME(ch, b) \
|
||||
( \
|
||||
dim_t NR, \
|
||||
int k, int n, \
|
||||
DTYPE_IN* bp, int rs_b, int cs_b, \
|
||||
DTYPE_IN* bpack \
|
||||
) \
|
||||
{ \
|
||||
\
|
||||
int p_idx, k_left, k_iter; \
|
||||
DTYPE_IN* bdest = bpack; \
|
||||
\
|
||||
k_left = k%8; \
|
||||
k_iter = k/8; \
|
||||
\
|
||||
int j = 0; \
|
||||
\
|
||||
for(int int4_j=0; int4_j<n; int4_j+=NR) { /* pack panels */ \
|
||||
int jb = bli_min(NR, n-int4_j); \
|
||||
\
|
||||
p_idx = 0; \
|
||||
if (jb == NR) { /* full size */ \
|
||||
for (int p=0; p<k_iter; p++) { \
|
||||
col_m_order_1(bdest, bp, (j+0), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+0), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+1), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+1), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+2), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+2), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+3), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+3), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+4), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+4), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+5), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+5), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+6), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+6), cs_b, rs_b); \
|
||||
col_m_order_1(bdest, bp, (j+7), cs_b, rs_b); \
|
||||
col_m_order_2(bdest, bp, (j+7), cs_b, rs_b); \
|
||||
p_idx += 8; \
|
||||
} \
|
||||
\
|
||||
/* handle edge cases if there are any */ \
|
||||
if(k_left == 7) { \
|
||||
bpad_col_kleft7(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 6) { \
|
||||
bpad_col_kleft6(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 5) { \
|
||||
bpad_col_kleft5(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 4) { \
|
||||
bpad_col_kleft4(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 3) { \
|
||||
bpad_col_kleft3(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 2) { \
|
||||
bpad_col_kleft2(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 1) { \
|
||||
bpad_col_kleft1(bdest, bp, cs_b, rs_b); \
|
||||
} \
|
||||
} \
|
||||
else { /* not full size */ \
|
||||
for (int p=0; p<k_iter; p++) { \
|
||||
for (int jr=0; jr<jb; jr++) { \
|
||||
if (jr%2==0) { \
|
||||
col_m_order_1(bdest, bp, (j+jr/2), cs_b, rs_b); \
|
||||
} \
|
||||
else { \
|
||||
col_m_order_2(bdest, bp, (j+jr/2), cs_b, rs_b); \
|
||||
} \
|
||||
} \
|
||||
for (int jr=jb; jr<NR; jr++) { \
|
||||
zero_out_dest(bdest); \
|
||||
} \
|
||||
p_idx += 8; \
|
||||
} \
|
||||
\
|
||||
/* handle edge cases if there are any */ \
|
||||
if(k_left == 7) { \
|
||||
edge7(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 6) { \
|
||||
edge6(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 5) { \
|
||||
edge5(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 4) { \
|
||||
edge4(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 3) { \
|
||||
edge3(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 2) { \
|
||||
edge2(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
else if(k_left == 1) { \
|
||||
edge1(bdest, bp, j, jb, cs_b, rs_b); \
|
||||
} \
|
||||
\
|
||||
/* fill in zeros when an edge case occurs */ \
|
||||
if(k_left!=0) \
|
||||
{ \
|
||||
for (int ir=jb; ir<NR; ir++) \
|
||||
zero_out_dest(bdest); \
|
||||
} \
|
||||
} \
|
||||
j += (NR/2); \
|
||||
} \
|
||||
}
|
||||
@@ -32,27 +32,57 @@
|
||||
|
||||
*/
|
||||
|
||||
// Prototypes and template for the 5-loop gemm algorithm
|
||||
/*
|
||||
|
||||
Details on bit16_dt vector data structure
|
||||
|
||||
Vector X = [ X[0,0] X[0,1] X[1,0] X[1,1] X[2,0] X[2,1] X[3,0] X[3,1] ]
|
||||
Vector Y = [ Y[0,0] Y[0,1] Y[1,0] Y[1,1] Y[2,0] Y[2,1] Y[3,0] Y[3,1] ]
|
||||
|
||||
These bit16_dt vectors represent a 4x2 matrix. Hence, in matrix form it
|
||||
looks like the following:
|
||||
|
||||
X = [ X[0,0] X[0,1]
|
||||
X[1,0] X[1,1]
|
||||
X[2,0] X[2,1]
|
||||
X[3,0] X[3,1] ]
|
||||
|
||||
The outer product instruction: xvbf16ger2 (bfloat16 outer product)
|
||||
|
||||
Syntax:
|
||||
|
||||
xvbf16ger2 ACCUMULATOR A, VECTOR X, VECTOR Y
|
||||
|
||||
Semantics:
|
||||
|
||||
A = X * Y^T
|
||||
|
||||
The generic packing routine would load 8 elements from the same column.
|
||||
This causes an issue since the instruction expects the vector to be a
|
||||
4x2 matrix where the data is packed in contiguous order. Thus, we must make
|
||||
a packing routine that will interleave the matrix data. Making it so
|
||||
that when we load the 8 contiguous elements from A, it will represent
|
||||
a 4x2 section of the matrix.
|
||||
|
||||
*/
|
||||
|
||||
#include "pack_a_templates.h"
|
||||
#include "pack_b_templates.h"
|
||||
#include "bli_sandbox.h"
|
||||
|
||||
#define GEMM_PASTEMAC_(ch) bli_ ## ch ## gemm_
|
||||
#define GEMM_PASTEMAC(ch) GEMM_PASTEMAC_(ch)
|
||||
// 16 bit routines
|
||||
BIT16_PACK_A(sb, bfloat16);
|
||||
BIT16_PACK_B(sb, bfloat16);
|
||||
BIT16_PACK_A(sh, float16);
|
||||
BIT16_PACK_B(sh, float16);
|
||||
BIT16_PACK_A(i16, int16_t);
|
||||
BIT16_PACK_B(i16, int16_t);
|
||||
|
||||
#define GENERIC_GEMM_PROTO(ch, DTYPE_IN, DTYPE_OUT) \
|
||||
void GEMM_PASTEMAC(ch) \
|
||||
( \
|
||||
dim_t MR, dim_t NR, dim_t KC, dim_t NC, dim_t MC, \
|
||||
int m, int n, int k, \
|
||||
DTYPE_IN* restrict A, int rs_a, int cs_a, int A_align, \
|
||||
DTYPE_IN* restrict B, int rs_b, int cs_b, int B_align, \
|
||||
DTYPE_OUT* restrict C, int rs_c, int cs_c, \
|
||||
DTYPE_OUT* alpha, DTYPE_OUT* beta \
|
||||
)
|
||||
// 8 bit
|
||||
BIT8_PACK_A(i8, int8_t);
|
||||
BIT8_PACK_B(i8, int8_t);
|
||||
|
||||
GENERIC_GEMM_PROTO( sb, bfloat16, float);
|
||||
GENERIC_GEMM_PROTO( sh, float16, float);
|
||||
GENERIC_GEMM_PROTO(i16, int16_t, int32_t);
|
||||
GENERIC_GEMM_PROTO( i8, int8_t, int32_t);
|
||||
GENERIC_GEMM_PROTO( i4, nibbles, int32_t);
|
||||
// 4 bit
|
||||
BIT4_PACK_A(i4, nibbles);
|
||||
BIT4_PACK_B(i4, nibbles);
|
||||
|
||||
Reference in New Issue
Block a user