mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Optimal rerouting of GEMV inputs to avoid packing
- Added conditional swapping of the input matrices and their strides for GEMV, based on whether transpose is toggled for the relevant matrix: the B matrix when m=1, and the A matrix when n=1.
- This swapping reroutes the inputs to the alternative variant (code path) through logical transposition, in order to avoid the packing cost for that matrix.
- Currently, this optimization is enabled only when no post-ops are involved. With post-ops, the incoming data (from the user) needs to be updated in some scenarios; this will be dealt with later.
AMD-Internal: [CPUPL-7323]
Co-authored-by: Vignesh Balasubramanian <vignbala@amd.com>
This commit is contained in:
committed by
GitHub
parent
98eeeb0ddb
commit
37f255821a
@@ -195,12 +195,26 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
mtag_b = PACK;
|
||||
}
|
||||
|
||||
// Temporary variables to store/transform the input before execution.
|
||||
dim_t m_use = m;
|
||||
dim_t n_use = n;
|
||||
dim_t k_use = k;
|
||||
const float* a_use = a;
|
||||
const float* b_use = b;
|
||||
float* c_use = c;
|
||||
inc_t rs_a_use = rs_a, cs_a_use = cs_a;
|
||||
inc_t rs_b_use = rs_b, cs_b_use = cs_b;
|
||||
inc_t rs_c_use = rs_c, cs_c_use = cs_c;
|
||||
AOCL_MEMORY_TAG mtag_a_use = mtag_a;
|
||||
AOCL_MEMORY_TAG mtag_b_use = mtag_b;
|
||||
// char order_use = order;
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed, post_op_list,
|
||||
( void* )c, ( void* )( &order ),
|
||||
( void* )c_use, ( void* )( &order ),
|
||||
m, n
|
||||
);
|
||||
|
||||
@@ -209,6 +223,98 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Induce operation transpose and/or swapped strides based on the input.
|
||||
// NOTE :
|
||||
// This logic is primarily used to decide what JIT kernels are to be
|
||||
// generated. Any logical induce that we perform(swapping strides/matrices)
|
||||
// would reflect in the DE when generating the kernel. Since the 5-loop
|
||||
// algorithm(framework) is detached from the CPP layer, it is expected
|
||||
// that the induced ordering is maintained when calling the 5-loop, which
|
||||
// would internally call the execute handler, thus mapping to the correct
|
||||
// kernel.
|
||||
|
||||
// Handling row-major storage.
|
||||
if ( is_row_major == TRUE )
|
||||
{
|
||||
// For now (with row-major inputs), we enable operation transpose only
|
||||
// for GEMV(when the appropriate operand transpose is toggled). This is
|
||||
// done in order to avoid packing cost.
|
||||
// GEMV : Output is always a vector, and thus transposing a "row-stored"
|
||||
// contiguous vector is still a contiguous vector(logically).
|
||||
// GEMM : Output is a matrix, and thus transposing a "row-major" matrix
|
||||
// would lead to a "column-major" matrix, which is not compatible with
|
||||
// the underlying kernels for now.
|
||||
|
||||
// This optimization is currently enabled only when post-ops are
|
||||
// disabled.
|
||||
|
||||
// For GEMV_M1
|
||||
if ( post_op_list[0].op_code == POST_OPS_DISABLE )
|
||||
{
|
||||
if ( ( m == 1 ) && bli_is_trans( blis_transb ) && ( mtag_b != REORDERED ) )
|
||||
{
|
||||
// NOTE : We will reorder the inputs such that we use the GEMV_N1
|
||||
// kernel instead of GEMV_M1, in order to avoid packing of B
|
||||
// matrix(if not already reordered). The GEMV_N1 kernels support
|
||||
// both unit/non-unit strided loads/stores for C vector. Thus, we
|
||||
// would be packing the input vector alone(if needed).
|
||||
m_use = n;
|
||||
n_use = m;
|
||||
a_use = b;
|
||||
rs_a_use = cs_b;
|
||||
cs_a_use = rs_b;
|
||||
b_use = a;
|
||||
rs_b_use = cs_a;
|
||||
cs_b_use = rs_a;
|
||||
rs_c_use = cs_c;
|
||||
cs_c_use = rs_c;
|
||||
mtag_a_use = UNPACKED;
|
||||
mtag_b_use = mtag_a;
|
||||
// order_use = 'c';
|
||||
}
|
||||
// For GEMV_N1
|
||||
// The library does not support reorder of A, thereby not needing an
|
||||
// explicit check.
|
||||
else if ( ( n == 1 ) && bli_is_trans( blis_transa ) && ( rs_c == 1 ) )
|
||||
{
|
||||
// NOTE : We will reorder the inputs such that we use the GEMV_M1
|
||||
// kernel instead of GEMV_N1, in order to avoid packing of A
|
||||
// matrix. The GEMV_M1 kernels(both classic and JIT) support only
|
||||
// unit-strided C vector(row-stored). Thus, we need an explicit
|
||||
// check for that.
|
||||
m_use = n;
|
||||
n_use = m;
|
||||
a_use = b;
|
||||
rs_a_use = cs_b;
|
||||
cs_a_use = rs_b;
|
||||
b_use = a;
|
||||
rs_b_use = cs_a;
|
||||
cs_b_use = rs_a;
|
||||
rs_c_use = cs_c;
|
||||
cs_c_use = rs_c;
|
||||
mtag_a_use = mtag_b;
|
||||
mtag_b_use = UNPACKED;
|
||||
// order_use = 'c';
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handling column-major storage.
|
||||
else
|
||||
{
|
||||
m_use = n;
|
||||
n_use = m;
|
||||
a_use = b;
|
||||
rs_a_use = rs_b;
|
||||
cs_a_use = cs_b;
|
||||
b_use = a;
|
||||
rs_b_use = rs_a;
|
||||
cs_b_use = cs_a;
|
||||
rs_c_use = rs_c;
|
||||
cs_c_use = cs_c;
|
||||
mtag_a_use = mtag_b;
|
||||
mtag_b_use = mtag_a;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
@@ -219,31 +325,14 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
|
||||
if ( is_single_thread( &rntm_g ) == TRUE )
|
||||
{
|
||||
if ( ( is_row_major == TRUE ) &&
|
||||
( is_tiny_input_f32( m, n, k, lcntx_g ) == TRUE ) )
|
||||
if ( is_tiny_input_f32( m_use, n_use, k_use, lcntx_g ) == TRUE )
|
||||
{
|
||||
lpgemm_rowvar_tiny_f32f32f32of32
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
else if ( ( is_column_major == TRUE ) &&
|
||||
( is_tiny_input_f32( n, m, k, lcntx_g ) == TRUE ) )
|
||||
{
|
||||
lpgemm_rowvar_tiny_f32f32f32of32
|
||||
(
|
||||
n, m, k,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
c, rs_c, cs_c,
|
||||
m_use, n_use, k_use,
|
||||
a_use, rs_a_use, cs_a_use, mtag_a_use,
|
||||
b_use, rs_b_use, cs_b_use, mtag_b_use,
|
||||
c_use, rs_c_use, cs_c_use,
|
||||
alpha, beta,
|
||||
lcntx_g,
|
||||
post_op_list, F32
|
||||
@@ -257,61 +346,30 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
// The lpgemm_cntx_t argument will be NULL for f32 since it still uses
|
||||
// BLIS cntx_t internally. It's a workaround for now and will be replaced
|
||||
// with lpgemm_cntx_t eventually.
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if ( is_column_major == TRUE )
|
||||
{
|
||||
lpgemm_f32f32f32of32_openmp_thread_decorator
|
||||
(
|
||||
n, m, k,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_f32f32f32of32_openmp_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
}
|
||||
lpgemm_f32f32f32of32_openmp_thread_decorator
|
||||
(
|
||||
m_use, n_use, k_use,
|
||||
a_use, rs_a_use, cs_a_use, mtag_a_use,
|
||||
b_use, rs_b_use, cs_b_use, mtag_b_use,
|
||||
c_use, rs_c_use, cs_c_use,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
|
||||
#else
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if ( is_column_major == TRUE )
|
||||
{
|
||||
lpgemm_f32f32f32of32_thread_decorator
|
||||
(
|
||||
n, m, k,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_f32f32f32of32_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
}
|
||||
|
||||
lpgemm_f32f32f32of32_thread_decorator
|
||||
(
|
||||
m_use, n_use, k_use,
|
||||
a_use, rs_a_use, cs_a_use, mtag_a_use,
|
||||
b_use, rs_b_use, cs_b_use, mtag_b_use,
|
||||
c_use, rs_c_use, cs_c_use,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
|
||||
Reference in New Issue
Block a user