Optimal rerouting of GEMV inputs to avoid packing

- Added conditional swapping of the input matrices and their
  strides for GEMV, based on whether transpose is toggled for
  the relevant operand: the B matrix when m=1 and the A matrix
  when n=1.

- This swapping reroutes the inputs to the alternative variant
  (code path) so that the packing cost for that matrix is
  avoided through logical transposition (see the sketch after
  this list).

- Currently, this optimization is enabled only when no post-ops
  are involved. With post-ops, the incoming (user-provided) data
  would need to be updated in some scenarios; this will be
  handled later.
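
  For illustration, a minimal self-contained sketch of the
  "logical transposition" above (plain C, independent of the
  library; gemm_ref and all names here are illustrative, not
  library code). Transposing a strided matrix is just swapping
  its row/column strides, and since the GEMV output is a vector,
  C and C^T share the same memory layout; so C = A * B^T (m=1)
  can be rerouted as C^T = B * A^T without moving any data:

    #include <stdio.h>

    /* Reference GEMM with explicit strides:
     * C(m x n) += A(m x k) * B(k x n). */
    static void gemm_ref( int m, int n, int k,
                          const float* a, int rs_a, int cs_a,
                          const float* b, int rs_b, int cs_b,
                          float* c, int rs_c, int cs_c )
    {
        for ( int i = 0; i < m; ++i )
            for ( int j = 0; j < n; ++j )
                for ( int p = 0; p < k; ++p )
                    c[ i*rs_c + j*cs_c ] +=
                        a[ i*rs_a + p*cs_a ] * b[ p*rs_b + j*cs_b ];
    }

    int main( void )
    {
        enum { K = 3, N = 4 };
        float a[K]   = { 1, 2, 3 };               /* A is 1 x K            */
        float b[N*K] = { 1, 0, 2,                 /* B is N x K, row-major */
                         0, 1, 0,
                         3, 0, 1,
                         1, 1, 1 };
        float c1[N] = { 0 }, c2[N] = { 0 };

        /* Direct form: C(1 x N) = A(1 x K) * B^T, where B^T is
         * expressed logically by swapping B's strides (rs=1, cs=K).
         * This is the GEMV_M1 shape. */
        gemm_ref( 1, N, K, a, K, 1, b, 1, K, c1, N, 1 );

        /* Rerouted form: C^T(N x 1) = B(N x K) * A^T(K x 1).
         * Operands, dimensions and strides are swapped; no data is
         * moved, and C is a vector, so c2 has the same layout as c1.
         * This is the GEMV_N1 shape. */
        gemm_ref( N, 1, K, b, K, 1, a, 1, K, c2, 1, N );

        for ( int j = 0; j < N; ++j )
            printf( "%g %g\n", c1[j], c2[j] );    /* rows match: 7 2 6 6 */
        return 0;
    }

  Both calls write the same values into c1 and c2; the reroute only
  changes which kernel shape (GEMV_M1 vs GEMV_N1) services the
  computation.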

AMD-Internal: [CPUPL-7323]

Co-authored-by: Vignesh Balasubramanian <vignbala@amd.com>

@@ -195,12 +195,26 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
         mtag_b = PACK;
     }
+    // Temporary variables to store/transform the input before execution.
+    dim_t m_use = m;
+    dim_t n_use = n;
+    dim_t k_use = k;
+    const float* a_use = a;
+    const float* b_use = b;
+    float* c_use = c;
+    inc_t rs_a_use = rs_a, cs_a_use = cs_a;
+    inc_t rs_b_use = rs_b, cs_b_use = cs_b;
+    inc_t rs_c_use = rs_c, cs_c_use = cs_c;
+    AOCL_MEMORY_TAG mtag_a_use = mtag_a;
+    AOCL_MEMORY_TAG mtag_b_use = mtag_b;
+    // char order_use = order;
     // Convert post op struct to post op linked list format.
     lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
     err_t err = lpgemm_translate_to_post_ops_list
     (
       post_op_unparsed, post_op_list,
-      ( void* )c, ( void* )( &order ),
+      ( void* )c_use, ( void* )( &order ),
       m, n
     );
@@ -209,6 +223,98 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
         goto err_hndl;
     }
+    // Induce operation transpose and/or swapped strides based on the input.
+    // NOTE :
+    // This logic is primarily used to decide what JIT kernels are to be
+    // generated. Any logical induce that we perform (swapping strides/
+    // matrices) would reflect in the DE when generating the kernel. Since
+    // the 5-loop algorithm (framework) is detached from the CPP layer, it
+    // is expected that the induced ordering is maintained when calling the
+    // 5-loop, which would internally call the execute handler, thus
+    // mapping to the correct kernel.
+    // Handling row-major storage.
+    if ( is_row_major == TRUE )
+    {
+        // For now (with row-major inputs), we enable operation transpose
+        // only for GEMV (when the appropriate operand transpose is
+        // toggled). This is done in order to avoid packing cost.
+        // GEMV : Output is always a vector, and thus transposing a
+        // "row-stored" contiguous vector is still a contiguous vector
+        // (logically).
+        // GEMM : Output is a matrix, and thus transposing a "row-major"
+        // matrix would lead to a "column-major" matrix, which is not
+        // compatible with the underlying kernels for now.
+        // This optimization is currently enabled only when post-ops are
+        // disabled.
+        // For GEMV_M1
+        if ( post_op_list[0].op_code == POST_OPS_DISABLE )
+        {
+            if ( ( m == 1 ) && bli_is_trans( blis_transb ) && ( mtag_b != REORDERED ) )
+            {
+                // NOTE : We will reorder the inputs such that we use the
+                // GEMV_N1 kernel instead of GEMV_M1, in order to avoid
+                // packing of the B matrix (if not already reordered). The
+                // GEMV_N1 kernels support both unit/non-unit strided
+                // loads/stores for the C vector. Thus, we would be packing
+                // the input vector alone (if needed).
+                m_use = n;
+                n_use = m;
+                a_use = b;
+                rs_a_use = cs_b;
+                cs_a_use = rs_b;
+                b_use = a;
+                rs_b_use = cs_a;
+                cs_b_use = rs_a;
+                rs_c_use = cs_c;
+                cs_c_use = rs_c;
+                mtag_a_use = UNPACKED;
+                mtag_b_use = mtag_a;
+                // order_use = 'c';
+            }
+            // For GEMV_N1
+            // The library does not support reorder of A, so no explicit
+            // check is needed.
+            else if ( ( n == 1 ) && bli_is_trans( blis_transa ) && ( rs_c == 1 ) )
+            {
+                // NOTE : We will reorder the inputs such that we use the
+                // GEMV_M1 kernel instead of GEMV_N1, in order to avoid
+                // packing of the A matrix. The GEMV_M1 kernels (both
+                // classic and JIT) support only a unit-strided C vector
+                // (row-stored). Thus, we need an explicit check for that.
+                m_use = n;
+                n_use = m;
+                a_use = b;
+                rs_a_use = cs_b;
+                cs_a_use = rs_b;
+                b_use = a;
+                rs_b_use = cs_a;
+                cs_b_use = rs_a;
+                rs_c_use = cs_c;
+                cs_c_use = rs_c;
+                mtag_a_use = mtag_b;
+                mtag_b_use = UNPACKED;
+                // order_use = 'c';
+            }
+        }
+    }
+    // Handling column-major storage.
+    else
+    {
+        m_use = n;
+        n_use = m;
+        a_use = b;
+        rs_a_use = rs_b;
+        cs_a_use = cs_b;
+        b_use = a;
+        rs_b_use = rs_a;
+        cs_b_use = cs_a;
+        rs_c_use = rs_c;
+        cs_c_use = cs_c;
+        mtag_a_use = mtag_b;
+        mtag_b_use = mtag_a;
+    }
     // Initialize a local runtime with global settings if necessary. Note
     // that in the case that a runtime is passed in, we make a local copy.
     rntm_t rntm_g;
@@ -219,31 +325,14 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
     if ( is_single_thread( &rntm_g ) == TRUE )
     {
-        if ( ( is_row_major == TRUE ) &&
-             ( is_tiny_input_f32( m, n, k, lcntx_g ) == TRUE ) )
+        if ( is_tiny_input_f32( m_use, n_use, k_use, lcntx_g ) == TRUE )
         {
             lpgemm_rowvar_tiny_f32f32f32of32
             (
-                m, n, k,
-                a, rs_a, cs_a, mtag_a,
-                b, rs_b, cs_b, mtag_b,
-                c, rs_c, cs_c,
-                alpha, beta,
-                lcntx_g,
-                post_op_list, F32
-            );
-            return;
-        }
-        else if ( ( is_column_major == TRUE ) &&
-                  ( is_tiny_input_f32( n, m, k, lcntx_g ) == TRUE ) )
-        {
-            lpgemm_rowvar_tiny_f32f32f32of32
-            (
-                n, m, k,
-                b, rs_b, cs_b, mtag_b,
-                a, rs_a, cs_a, mtag_a,
-                c, rs_c, cs_c,
+                m_use, n_use, k_use,
+                a_use, rs_a_use, cs_a_use, mtag_a_use,
+                b_use, rs_b_use, cs_b_use, mtag_b_use,
+                c_use, rs_c_use, cs_c_use,
                 alpha, beta,
                 lcntx_g,
                 post_op_list, F32
@@ -257,61 +346,30 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
     // The lpgemm_cntx_t argument will be NULL for f32 since it still uses
     // BLIS cntx_t internally. It's a workaround for now and will be
     // replaced with lpgemm_cntx_t eventually.
-    // Swapping inputs to induce row major computation for column major inputs.
-    if ( is_column_major == TRUE )
-    {
-        lpgemm_f32f32f32of32_openmp_thread_decorator
-        (
-            n, m, k,
-            b, rs_b, cs_b, mtag_b,
-            a, rs_a, cs_a, mtag_a,
-            c, rs_c, cs_c,
-            alpha, beta,
-            &rntm_g, lcntx_g,
-            post_op_list, F32
-        );
-    }
-    else
-    {
-        lpgemm_f32f32f32of32_openmp_thread_decorator
-        (
-            m, n, k,
-            a, rs_a, cs_a, mtag_a,
-            b, rs_b, cs_b, mtag_b,
-            c, rs_c, cs_c,
-            alpha, beta,
-            &rntm_g, lcntx_g,
-            post_op_list, F32
-        );
-    }
+    lpgemm_f32f32f32of32_openmp_thread_decorator
+    (
+        m_use, n_use, k_use,
+        a_use, rs_a_use, cs_a_use, mtag_a_use,
+        b_use, rs_b_use, cs_b_use, mtag_b_use,
+        c_use, rs_c_use, cs_c_use,
+        alpha, beta,
+        &rntm_g, lcntx_g,
+        post_op_list, F32
+    );
 #else
-    // Swapping inputs to induce row major computation for column major inputs.
-    if ( is_column_major == TRUE )
-    {
-        lpgemm_f32f32f32of32_thread_decorator
-        (
-            n, m, k,
-            b, rs_b, cs_b, mtag_b,
-            a, rs_a, cs_a, mtag_a,
-            c, rs_c, cs_c,
-            alpha, beta,
-            &rntm_g, lcntx_g,
-            post_op_list, F32
-        );
-    }
-    else
-    {
-        lpgemm_f32f32f32of32_thread_decorator
-        (
-            m, n, k,
-            a, rs_a, cs_a, mtag_a,
-            b, rs_b, cs_b, mtag_b,
-            c, rs_c, cs_c,
-            alpha, beta,
-            &rntm_g, lcntx_g,
-            post_op_list, F32
-        );
-    }
+    lpgemm_f32f32f32of32_thread_decorator
+    (
+        m_use, n_use, k_use,
+        a_use, rs_a_use, cs_a_use, mtag_a_use,
+        b_use, rs_b_use, cs_b_use, mtag_b_use,
+        c_use, rs_c_use, cs_c_use,
+        alpha, beta,
+        &rntm_g, lcntx_g,
+        post_op_list, F32
+    );
 #endif
 err_hndl:;
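
For reference, a condensed view of when the operand swap fires,
mirroring the conditions added in the diff above (a sketch only; the
function and enum names are hypothetical, not library API):

    typedef enum { NO_SWAP, SWAP_OPERANDS } swap_decision_t;

    /* Hypothetical summary of the dispatch logic above: swap A/B (and
     * their strides) either to induce row-major computation for
     * column-major inputs, or to reroute GEMV to the variant that
     * avoids packing. */
    static swap_decision_t lpgemm_f32_swap_decision
    (
        int row_major, long m, long n,
        int trans_a, int trans_b,
        int b_reordered, long rs_c,
        int has_post_ops
    )
    {
        if ( !row_major )    /* Column-major: always swap, to compute */
            return SWAP_OPERANDS;               /* in row-major form. */
        if ( has_post_ops )  /* Reroute disabled with post-ops for now. */
            return NO_SWAP;
        /* GEMV_M1 -> GEMV_N1: avoids packing B (unless already reordered). */
        if ( ( m == 1 ) && trans_b && !b_reordered )
            return SWAP_OPERANDS;
        /* GEMV_N1 -> GEMV_M1: avoids packing A; GEMV_M1 requires a
         * unit-strided (row-stored) C vector. */
        if ( ( n == 1 ) && trans_a && ( rs_c == 1 ) )
            return SWAP_OPERANDS;
        return NO_SWAP;
    }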