mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Removed unnecessary pack checks in FP32 GEMV (#54)
Details: - In FP32 GEMM, when threading is disabled, rntm_pack_a and rntm_pack_b were set to true by default. This leads to perf regression for smaller sizes. Modified FP32 interface API to not overwrite the packA and packB variables in rntm structure. - In FP32 GEMV, Removed the decision making code based on mtag_A/B and should_pack_A/B for packing. Matrices will be packed only if the storage format of the matrices doesn't match the storage format required by the kernel. - Changed the control flow of checking the value of mtag to whether matrix is "reordered" or "to-be-packed" or "unpacked". checking for "reorder" first, followed by "pack". This will ensure that packing doesn't happen when the matrix is already reordered even though user forces packing by setting "BLIS_PACK_A/B" -Modified python script to generate testcases based on block sizes AMD-Internal: SWLCSG-3527
This commit is contained in:
committed by
GitHub
parent
1847a1e8c6
commit
8649cdc14b
@@ -285,10 +285,6 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
);
|
||||
}
|
||||
#else
|
||||
// Setting pack A and B by default for non open mp case.
|
||||
bli_rntm_set_pack_a( 1, &rntm_g );
|
||||
bli_rntm_set_pack_b( 1, &rntm_g );
|
||||
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if ( is_column_major == TRUE )
|
||||
{
|
||||
|
||||
@@ -124,6 +124,10 @@ typedef void (*lpgemv_a_pack_ft)
|
||||
|
||||
LPGEMV(float, float, float, f32f32f32of32)
|
||||
{
|
||||
|
||||
/* Ignoring mtag_a/b and should_pack_A/B for now .
|
||||
Matrices are packed only when the storage format is not supported by the kernel.
|
||||
*/
|
||||
const float* a_use = (float*)a;
|
||||
inc_t rs_a_use = rs_a;
|
||||
inc_t cs_a_use = cs_a;
|
||||
@@ -154,12 +158,6 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
if (c_downscale < F32) post_ops_attr.buf_downscale = c;
|
||||
else post_ops_attr.buf_downscale = NULL;
|
||||
|
||||
// Should_pack_A/B is set either by the user through env variable
|
||||
// or by the smart threading logic based on work distribution.
|
||||
// Storage format of the matrices doesn't affect should_pack_A/B.
|
||||
bool should_pack_B = bli_rntm_pack_b( rntm );
|
||||
bool should_pack_A = bli_rntm_pack_a( rntm );
|
||||
|
||||
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
|
||||
thrinfo_t thread_jc;
|
||||
thrinfo_t thread_ic;
|
||||
@@ -195,7 +193,7 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
packa_fp = packa_mr8_f32f32f32of32_col_major;
|
||||
#endif
|
||||
// Pack B matrix if rs_b > 1
|
||||
if( (should_pack_B == TRUE) || ( rs_b != 1 ) )
|
||||
if( rs_b != 1 )
|
||||
{
|
||||
mem_b_size_req = sizeof( float ) * k;
|
||||
|
||||
@@ -233,7 +231,7 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
post_ops_attr.post_op_c_i = ic;
|
||||
|
||||
// To-Do: pack A case needs to be handled for AVX2 case.
|
||||
if( (should_pack_A == TRUE) || ( cs_a != 1 ) )
|
||||
if( cs_a != 1 )
|
||||
{
|
||||
mem_a_size_req = sizeof(float) * mc0 * k;
|
||||
lpgemm_alloc_mem_panel
|
||||
@@ -264,11 +262,11 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
&post_ops_attr
|
||||
);
|
||||
}
|
||||
if ( ( (should_pack_A == TRUE) || ( cs_a != 1 ) ) && ( bli_mem_is_alloc( &mem_a ) ) )
|
||||
if ( ( cs_a != 1 ) && ( bli_mem_is_alloc( &mem_a ) ) )
|
||||
{
|
||||
bli_pba_release( rntm, &mem_a );
|
||||
}
|
||||
if ( ( (should_pack_B == TRUE) || ( rs_b != 1 ) ) && ( bli_mem_is_alloc( &mem_b ) ) )
|
||||
if ( ( rs_b != 1 ) && ( bli_mem_is_alloc( &mem_b ) ) )
|
||||
{
|
||||
bli_pba_release( rntm, &mem_b );
|
||||
}
|
||||
@@ -300,7 +298,7 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
thread_jc.work_id = thread->tid;
|
||||
bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
|
||||
|
||||
if ( (should_pack_A == TRUE) || ( cs_a != 1 ) )
|
||||
if ( cs_a != 1 )
|
||||
{
|
||||
mem_a_size_req = sizeof( float ) * k;
|
||||
|
||||
@@ -346,7 +344,7 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
rs_b_use = NR;
|
||||
cs_b_use = 1;
|
||||
}
|
||||
else if ( (should_pack_B == TRUE) || ( mtag_b == PACK ) )
|
||||
else if ( mtag_b == PACK )
|
||||
{
|
||||
// nc0 needs to be a multiple of 16 since this gives maximum
|
||||
// vectorization. Packing B always results in buffers with width
|
||||
@@ -412,12 +410,12 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
} // jc loop
|
||||
|
||||
// Release pack buffers.
|
||||
if ( ( (should_pack_B == TRUE) || ( mtag_b == PACK ) ) && ( bli_mem_is_alloc( &mem_b ) ) )
|
||||
if ( ( mtag_b == PACK ) && ( bli_mem_is_alloc( &mem_b ) ) )
|
||||
{
|
||||
bli_pba_release( rntm, &mem_b );
|
||||
}
|
||||
|
||||
if ( ( (should_pack_A == TRUE) || ( cs_a != 1 ) ) && ( bli_mem_is_alloc( &mem_a ) ) )
|
||||
if ( ( cs_a != 1 ) && ( bli_mem_is_alloc( &mem_a ) ) )
|
||||
{
|
||||
bli_pba_release( rntm, &mem_a );
|
||||
}
|
||||
@@ -569,7 +567,20 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE );
|
||||
post_ops_attr.is_last_k = is_last_k;
|
||||
|
||||
if ( ( mtag_b == PACK ) || ( should_pack_B == TRUE ) )
|
||||
if ( mtag_b == REORDERED )
|
||||
{
|
||||
// In multi-threaded scenarios, an extra offset into a given
|
||||
// packed B panel is required, since the jc loop split can
|
||||
// result in per thread start offset inside the panel, instead
|
||||
// of panel boundaries.
|
||||
b_use = b + ( jc_cur_loop * k ) +
|
||||
( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 );
|
||||
|
||||
rs_b_use = NR;
|
||||
cs_b_use = 1;
|
||||
ps_b_use = kc0;
|
||||
}
|
||||
else if ( ( mtag_b == PACK ) || ( should_pack_B == TRUE ) )
|
||||
{
|
||||
// Pack B chunks are based on jc work id.
|
||||
dim_t jc_work_id = bli_thread_work_id( &thread_jc );
|
||||
@@ -649,19 +660,6 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
);
|
||||
b_use = pack_b_buffer_f32f32f32of32;
|
||||
}
|
||||
else if ( mtag_b == REORDERED )
|
||||
{
|
||||
// In multi-threaded scenarios, an extra offset into a given
|
||||
// packed B panel is required, since the jc loop split can
|
||||
// result in per thread start offset inside the panel, instead
|
||||
// of panel boundaries.
|
||||
b_use = b + ( jc_cur_loop * k ) +
|
||||
( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 );
|
||||
|
||||
rs_b_use = NR;
|
||||
cs_b_use = 1;
|
||||
ps_b_use = kc0;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_use = b + ( pc * rs_b ) + ( jc * cs_b );
|
||||
|
||||
@@ -266,7 +266,13 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
|
||||
cs_a_use = 1;
|
||||
}
|
||||
|
||||
if ( ( mtag_b == PACK ) )
|
||||
if ( mtag_b == REORDERED )
|
||||
{
|
||||
b_use = ( float* )b;
|
||||
rs_b_use = NR;
|
||||
cs_b_use = 1;
|
||||
}
|
||||
else if ( ( mtag_b == PACK ) )
|
||||
{
|
||||
dim_t nc0_updated = make_multiple_of_n(n, NR);
|
||||
siz_t mem_b_size_req = sizeof(float) * nc0_updated * k;
|
||||
@@ -288,12 +294,6 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
|
||||
|
||||
b_use = pack_b_buffer_f32f32f32of32;
|
||||
}
|
||||
else if ( mtag_b == REORDERED )
|
||||
{
|
||||
b_use = ( float* )b;
|
||||
rs_b_use = NR;
|
||||
cs_b_use = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_use = ( float* )b;
|
||||
@@ -388,7 +388,14 @@ LPGEMM_TINY(float,float,float,f32f32f32of32)
|
||||
// Even if the mtag_b is set to PACK, for tiny sizes its better to
|
||||
// pack only if it affects output accuracy (like column major B),
|
||||
// else ignore it.
|
||||
if ( ( mtag_b == PACK ) )
|
||||
if ( mtag_b == REORDERED )
|
||||
{
|
||||
b_use = b;
|
||||
rs_b_use = NR;
|
||||
cs_b_use = 1;
|
||||
ps_b_use = k;
|
||||
}
|
||||
else if ( ( mtag_b == PACK ) )
|
||||
{
|
||||
dim_t nc0_updated = make_multiple_of_n( n, NR );
|
||||
mem_b_size_req = sizeof( float ) * nc0_updated * k;
|
||||
@@ -410,13 +417,6 @@ LPGEMM_TINY(float,float,float,f32f32f32of32)
|
||||
|
||||
b_use = pack_b_buffer_f32f32f32of32;
|
||||
}
|
||||
else if ( mtag_b == REORDERED )
|
||||
{
|
||||
b_use = b;
|
||||
rs_b_use = NR;
|
||||
cs_b_use = 1;
|
||||
ps_b_use = k;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_use = b;
|
||||
|
||||
@@ -12,42 +12,107 @@
|
||||
#
|
||||
# To-Do: Add more testcases to cover testing of 5-loop framework
|
||||
# taking blocksizes into consideration based on API being tested.
|
||||
import os
|
||||
from enum import Enum, auto
|
||||
|
||||
class BlocksizeType(Enum):
|
||||
MC = 0
|
||||
NC = 1
|
||||
KC = 2
|
||||
MR = 3
|
||||
NR = 4
|
||||
PACKA_RS = 5
|
||||
PACKA_CS = 6
|
||||
PACKB_RS = 7
|
||||
PACKB_CS = 8
|
||||
|
||||
|
||||
# Helper function to get blocksize values
|
||||
def get_blocksize(data_type: str, blocksize_type: BlocksizeType) -> int:
|
||||
"""Get blocksize value using enum index"""
|
||||
if data_type in BLKSZ_MAP:
|
||||
return BLKSZ_MAP[data_type][blocksize_type.value]
|
||||
raise ValueError(f"Data type {data_type} not found in blocksize map")
|
||||
|
||||
|
||||
LPGEMM_BLKSZ_MAP_ZEN4 = {
|
||||
"u8s8s32": [144, 1024, 2048, 6, 64, 4, 24, 4*64, 64] ,
|
||||
"f32f32f32": [192, 8064, 512, 6, 64, 1, 6, 64, 1] ,
|
||||
"bf16bf16f32": [144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2] ,
|
||||
"bf16s4f32": [144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2] ,
|
||||
"s8s8s32": [144, 1024, 2048, 6, 64, 4, 24, 4*64, 64]
|
||||
}
|
||||
|
||||
|
||||
LPGEMM_BLKSZ_UPD_MAP_ZEN4_TO_ZEN = {
|
||||
"f32f32f32": [144, 8064, 512, 6, 64, 1, 6, 64, 1]
|
||||
}
|
||||
|
||||
LPGEMM_BLKSZ_MAP_ZEN = {
|
||||
"u8s8s32": [144, 1024, 2048, 6, 64, 4, 24, 4*64, 64] ,
|
||||
"f32f32f32": [144, 8064, 512, 6, 16, 1, 6, 16, 1] ,
|
||||
"bf16bf16f32": [144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2] ,
|
||||
"bf16s4f32": [144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2] ,
|
||||
"s8s8s32": [144, 1024, 2048, 6, 64, 4, 24, 4*64, 64]
|
||||
}
|
||||
|
||||
|
||||
# Get the environment variable for testcase generation
|
||||
GENERATE_TESTCASES_FOR = os.getenv('GENERATE_TESTCASES_FOR', 'zen4').lower()
|
||||
|
||||
# Select the appropriate blocksize map based on environment variable
|
||||
if GENERATE_TESTCASES_FOR == 'zen4':
|
||||
BLKSZ_MAP = LPGEMM_BLKSZ_MAP_ZEN4
|
||||
elif GENERATE_TESTCASES_FOR == 'zen':
|
||||
BLKSZ_MAP = LPGEMM_BLKSZ_MAP_ZEN
|
||||
elif GENERATE_TESTCASES_FOR == 'zen4_to_zen':
|
||||
BLKSZ_MAP = LPGEMM_BLKSZ_UPD_MAP_ZEN4_TO_ZEN
|
||||
else:
|
||||
raise ValueError(f"Invalid value for GENERATE_TESTCASES_FOR: {GENERATE_TESTCASES_FOR}. Must be one of: zen4, zen, zen4_to_zen")
|
||||
|
||||
ops = {
|
||||
#In,acc type: [out_types]
|
||||
"f32f32f32" : ["f32"],
|
||||
"bf16bf16f32" : ["f32", "bf16"],
|
||||
#"bf16bf16f32" : ["f32", "bf16"],
|
||||
#"bf16s4f32" : ["f32", "bf16"],
|
||||
"s8s8s32" : ["s32", "u8", "s8", "bf16", "f32"],
|
||||
"u8s8s32" : ["s32", "u8", "s8", "bf16", "f32"]
|
||||
# "s8s8s32" : ["s32", "u8", "s8", "bf16", "f32"],
|
||||
# "u8s8s32" : ["s32", "u8", "s8", "bf16", "f32"]
|
||||
}
|
||||
post_ops = ["none", "bias", "relu", "prelu", "clip", "matrix_add", "matrix_mul",
|
||||
post_ops = ["none"]#["none", "bias", "relu", "prelu", "clip", "matrix_add", "matrix_mul",
|
||||
# "swish", "gelu_tanh", "gelu_erf", "tanh", "sigmoid",
|
||||
"scale=scalar,zp=scalar", "scale=vector,zp=scalar", "scale=scalar,zp=vector","scale=vector,zp=vector"]
|
||||
#"scale=scalar,zp=scalar", "scale=vector,zp=scalar", "scale=scalar,zp=vector","scale=vector,zp=vector"]
|
||||
|
||||
|
||||
ofile = open("accuracy_test_data_lpgemm.txt", "w")
|
||||
import sys
|
||||
|
||||
for stor in ["r"]:
|
||||
packb_list = ["n", "p"] if stor == "c" else ["n", "p", "r"]
|
||||
for transa in ["n", "t"]:
|
||||
for transb in ["n", "t"]:
|
||||
for packa in ["n"]:
|
||||
for packb in packb_list:
|
||||
for m in range( 6, 0, -1):
|
||||
for n in [64, 48, 32, 16, 10, 1]:
|
||||
for k in [32, 256, 1024]:
|
||||
if( stor == "c" ):
|
||||
stride_a = m if transa == "n" else k
|
||||
stride_b = k if transb == "n" else n
|
||||
stride_c = m
|
||||
else:
|
||||
stride_a = k if transa == "n" else m
|
||||
stride_b = n if transb == "n" else k
|
||||
stride_c = n
|
||||
dims = " ".join([str(m), str(n), str(k), str(stride_a), str(stride_b), str(stride_c)])
|
||||
chars = " ".join([stor, transa, transb, packa, packb])
|
||||
for inputtypes in ops:
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python bench_data_gen_lpgemm.py <output_filename>")
|
||||
sys.exit(1)
|
||||
|
||||
output_filename = sys.argv[1]
|
||||
ofile = open(output_filename, "w")
|
||||
|
||||
|
||||
for inputtypes in ops:
|
||||
for stor in ["r", "c"]:
|
||||
packb_list = ["n", "p"] if stor == "c" else ["n", "p", "r"]
|
||||
for transa in ["n", "t"]:
|
||||
for transb in ["n", "t"]:
|
||||
for packa in ["n"]:
|
||||
for packb in packb_list:
|
||||
for m in range(1, (get_blocksize(inputtypes, BlocksizeType.MR) * 2)+ 1):
|
||||
for n in range(1, get_blocksize(inputtypes, BlocksizeType.NR) * 2 + 1):
|
||||
for k in [32, 256, get_blocksize(inputtypes, BlocksizeType.KC)]:
|
||||
if( stor == "c" ):
|
||||
stride_a = m if transa == "n" else k
|
||||
stride_b = k if transb == "n" else n
|
||||
stride_c = m
|
||||
else:
|
||||
stride_a = k if transa == "n" else m
|
||||
stride_b = n if transb == "n" else k
|
||||
stride_c = n
|
||||
dims = " ".join([str(m), str(n), str(k), str(stride_a), str(stride_b), str(stride_c)])
|
||||
chars = " ".join([stor, transa, transb, packa, packb])
|
||||
for output_type in ops[inputtypes]:
|
||||
op = inputtypes + "o" + output_type
|
||||
for post_op in post_ops:
|
||||
@@ -57,4 +122,37 @@ for stor in ["r"]:
|
||||
else:
|
||||
post_op += "=" + output_type
|
||||
ofile.write(chars + " " + dims + " " + op + ":" + post_op + "\n")
|
||||
|
||||
#5 loop testcases
|
||||
for inputtypes in ops:
|
||||
for stor in ["r", "c"]:
|
||||
packb_list = ["n", "p"] if stor == "c" else ["n", "p", "r"]
|
||||
for transa in ["n", "t"]:
|
||||
for transb in ["n", "t"]:
|
||||
for packa in ["n"]:
|
||||
for packb in packb_list:
|
||||
for m in [ get_blocksize(inputtypes, BlocksizeType.MC) * 2]:
|
||||
for n in [get_blocksize(inputtypes, BlocksizeType.NC) * 2]:
|
||||
for k in [get_blocksize(inputtypes, BlocksizeType.KC) * 2]:
|
||||
if( stor == "c" ):
|
||||
stride_a = m if transa == "n" else k
|
||||
stride_b = k if transb == "n" else n
|
||||
stride_c = m
|
||||
else:
|
||||
stride_a = k if transa == "n" else m
|
||||
stride_b = n if transb == "n" else k
|
||||
stride_c = n
|
||||
dims = " ".join([str(m), str(n), str(k), str(stride_a), str(stride_b), str(stride_c)])
|
||||
chars = " ".join([stor, transa, transb, packa, packb])
|
||||
for output_type in ops[inputtypes]:
|
||||
op = inputtypes + "o" + output_type
|
||||
for post_op in post_ops:
|
||||
if post_op == "bias" or post_op == "matrix_add" or post_op == "matrix_mul":
|
||||
if( output_type == "u8"):
|
||||
post_op += "=" + "na"
|
||||
else:
|
||||
post_op += "=" + output_type
|
||||
ofile.write(chars + " " + dims + " " + op + ":" + post_op + "\n")
|
||||
|
||||
|
||||
ofile.close()
|
||||
|
||||
Reference in New Issue
Block a user