diff --git a/frame/3/gemm/bli_gemm_md.c b/frame/3/gemm/bli_gemm_md.c index 0f82b15f3..c9450a26c 100644 --- a/frame/3/gemm/bli_gemm_md.c +++ b/frame/3/gemm/bli_gemm_md.c @@ -172,14 +172,12 @@ mddm_t bli_gemm_md_ccr // that computation datatype to query the corresponding ukernel output // preference. const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); - const bool row_pref - = bli_cntx_l3_nat_ukr_prefers_rows_dt( dt, BLIS_GEMM_UKR, *cntx ); // We can only perform this case of mixed-domain gemm, C += A*B where // B is real, if the microkernel prefers column output. If it prefers // row output, we must induce a transposition and perform C += A*B // where A (formerly B) is real. - if ( row_pref ) + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of_md( c, dt, BLIS_GEMM_UKR, *cntx ) ) { bli_obj_swap( a, b ); @@ -273,14 +271,12 @@ mddm_t bli_gemm_md_crc // that computation datatype to query the corresponding ukernel output // preference. const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); - const bool col_pref - = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, *cntx ); // We can only perform this case of mixed-domain gemm, C += A*B where // A is real, if the microkernel prefers row output. If it prefers // column output, we must induce a transposition and perform C += A*B // where B (formerly A) is real. - if ( col_pref ) + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of_md( c, dt, BLIS_GEMM_UKR, *cntx ) ) { bli_obj_swap( a, b ); diff --git a/frame/base/bli_cntx.h b/frame/base/bli_cntx.h index 3715d70c9..fcef4738f 100644 --- a/frame/base/bli_cntx.h +++ b/frame/base/bli_cntx.h @@ -601,6 +601,27 @@ BLIS_INLINE bool bli_cntx_l3_vir_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t uk !bli_cntx_l3_vir_ukr_prefers_storage_of( obj, ukr_id, cntx ); } +BLIS_INLINE bool bli_cntx_l3_vir_ukr_prefers_storage_of_md( obj_t* obj, num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +{ + // we use the computation datatype, which may differ from the + // storage datatype of C + const bool ukr_prefers_rows + = bli_cntx_l3_vir_ukr_prefers_rows_dt( dt, ukr_id, cntx ); + const bool ukr_prefers_cols + = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, ukr_id, cntx ); + bool r_val = FALSE; + + if ( bli_obj_is_row_stored( obj ) && ukr_prefers_rows ) r_val = TRUE; + else if ( bli_obj_is_col_stored( obj ) && ukr_prefers_cols ) r_val = TRUE; + return r_val; +} + +BLIS_INLINE bool bli_cntx_l3_vir_ukr_dislikes_storage_of_md( obj_t* obj, num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +{ + return ( bool ) + !bli_cntx_l3_vir_ukr_prefers_storage_of_md( obj, dt, ukr_id, cntx ); +} + // ----------------------------------------------------------------------------- BLIS_INLINE bool bli_cntx_l3_sup_thresh_is_met( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx ) {