Fix bug in DZGEMM

Details:
- In case of dzgemm, if the microkernel prefers column output,
  we will induce a transposition and perform C += A*B,
  where A (formerly B) becomes complex and B (formerly A)
  becomes real. Hence attach the complex alpha object to A instead of B.
- This commit reverts all the changes made by
  d62f12a18a and
  0b81f53074 as they are causing
  failures in make checkblis-md.

AMD-Internal: [CPUPL-2893]
Change-Id: I56b94ac136fb96003302c568ae2587142c836620
This commit is contained in:
Meghana Vankadari
2023-01-06 12:55:45 +00:00
parent dbd0c069d4
commit fbd9dde869
3 changed files with 118 additions and 11 deletions

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -132,7 +132,14 @@ void bli_gemm_front
// Attach alpha to B, and in the process typecast alpha to the target
// datatype of the matrix (which in this case is equal to the computation
// datatype).
bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &b_local );
// In case of dzgemm, if the microkernel prefers column output,
// we will induce a transposition and perform C += A*B
// where A (formerly B) is complex. Hence attach alpha to A.
if ( bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ))
bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &a_local );
else
bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &b_local );
// Attach beta to C, and in the process typecast beta to the target
// datatype of the matrix (which in this case is equal to the storage

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc.
Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -156,20 +156,92 @@ mddm_t bli_gemm_md_ccr
cntx_t** cntx
)
{
mddm_t doms;
// We assume that the requested computation domain is complex.
//dom_t dom_comp_in = bli_obj_comp_domain( c );
//dom_t dom_comp_in = BLIS_COMPLEX;
// For ccr, the computation (ukernel) will be real, but the execution
// will appear complex to other parts of the implementation.
doms.comp = BLIS_REAL;
doms.exec = BLIS_COMPLEX;
// Here we construct the computation datatype, which for the ccr case
// is equal to the real projection of the execution datatype, and use
// that computation datatype to query the corresponding ukernel output
// preference.
const num_t dt = BLIS_REAL | bli_obj_comp_prec( c );
const bool row_pref
= bli_cntx_l3_nat_ukr_prefers_rows_dt( dt, BLIS_GEMM_UKR, *cntx );
// We can only perform this case of mixed-domain gemm, C += A*B where
// B is real, if the microkernel prefers column output. If it prefers
// row output, we must induce a transposition and perform C += A*B
// where A (formerly B) is real.
if ( row_pref )
{
bli_obj_swap( a, b );
bli_obj_induce_trans( a );
bli_obj_induce_trans( b );
bli_obj_induce_trans( c );
return bli_gemm_md_crc( a, b, beta, c, cntx_local, cntx );
}
// Create a local copy of the context and then prepare to use this
// context instead of the one passed in.
*cntx_local = **cntx;
*cntx = cntx_local;
// We must induce a transposition and perform C += A*B
// where A (formerly B) is real.
bli_obj_swap( a, b );
// Copy the real domain blocksizes into the slots of their complex
// counterparts.
blksz_t* blksz_mr = bli_cntx_get_blksz( BLIS_MR, *cntx );
blksz_t* blksz_nr = bli_cntx_get_blksz( BLIS_NR, *cntx );
blksz_t* blksz_mc = bli_cntx_get_blksz( BLIS_MC, *cntx );
blksz_t* blksz_nc = bli_cntx_get_blksz( BLIS_NC, *cntx );
blksz_t* blksz_kc = bli_cntx_get_blksz( BLIS_KC, *cntx );
bli_obj_induce_trans( a );
bli_obj_induce_trans( b );
bli_obj_induce_trans( c );
bli_blksz_copy_dt( BLIS_FLOAT, blksz_mr, BLIS_SCOMPLEX, blksz_mr );
bli_blksz_copy_dt( BLIS_DOUBLE, blksz_mr, BLIS_DCOMPLEX, blksz_mr );
return bli_gemm_md_crc( a, b, beta, c, cntx_local, cntx );
bli_blksz_copy_dt( BLIS_FLOAT, blksz_nr, BLIS_SCOMPLEX, blksz_nr );
bli_blksz_copy_dt( BLIS_DOUBLE, blksz_nr, BLIS_DCOMPLEX, blksz_nr );
bli_blksz_copy_dt( BLIS_FLOAT, blksz_mc, BLIS_SCOMPLEX, blksz_mc );
bli_blksz_copy_dt( BLIS_DOUBLE, blksz_mc, BLIS_DCOMPLEX, blksz_mc );
bli_blksz_copy_dt( BLIS_FLOAT, blksz_nc, BLIS_SCOMPLEX, blksz_nc );
bli_blksz_copy_dt( BLIS_DOUBLE, blksz_nc, BLIS_DCOMPLEX, blksz_nc );
bli_blksz_copy_dt( BLIS_FLOAT, blksz_kc, BLIS_SCOMPLEX, blksz_kc );
bli_blksz_copy_dt( BLIS_DOUBLE, blksz_kc, BLIS_DCOMPLEX, blksz_kc );
// Halve both the real and complex MR's (which are both real MR's).
bli_blksz_scale_def_max( 1, 2, BLIS_FLOAT, blksz_mr );
bli_blksz_scale_def_max( 1, 2, BLIS_DOUBLE, blksz_mr );
bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_mr );
bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_mr );
// Halve both the real and complex MC's (which are both real MC's).
bli_blksz_scale_def_max( 1, 2, BLIS_FLOAT, blksz_mc );
bli_blksz_scale_def_max( 1, 2, BLIS_DOUBLE, blksz_mc );
bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_mc );
bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_mc );
// Use the default pack schemas in the context.
// static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx )
func_t* l3_vir_ukrs = bli_cntx_get_l3_vir_ukrs( BLIS_GEMM_UKR, *cntx );
// Rather than check which complex datatype dt_comp refers to, we set
// the mixed-domain virtual microkernel for both types.
bli_func_set_dt( bli_cgemm_md_c2r_ref, BLIS_SCOMPLEX, l3_vir_ukrs );
bli_func_set_dt( bli_zgemm_md_c2r_ref, BLIS_DCOMPLEX, l3_vir_ukrs );
// Return the computation and execution domains.
return doms;
}
// -----------------------------------------------------------------------------
@@ -196,6 +268,29 @@ mddm_t bli_gemm_md_crc
doms.comp = BLIS_REAL;
doms.exec = BLIS_COMPLEX;
// Here we construct the computation datatype, which for the crc case
// is equal to the real projection of the execution datatype, and use
// that computation datatype to query the corresponding ukernel output
// preference.
const num_t dt = BLIS_REAL | bli_obj_comp_prec( c );
const bool col_pref
= bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, *cntx );
// We can only perform this case of mixed-domain gemm, C += A*B where
// A is real, if the microkernel prefers row output. If it prefers
// column output, we must induce a transposition and perform C += A*B
// where B (formerly A) is real.
if ( col_pref )
{
bli_obj_swap( a, b );
bli_obj_induce_trans( a );
bli_obj_induce_trans( b );
bli_obj_induce_trans( c );
return bli_gemm_md_ccr( a, b, beta, c, cntx_local, cntx );
}
// Create a local copy of the context and then prepare to use this
// context instead of the one passed in.
*cntx_local = **cntx;

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -49,6 +49,11 @@ void bli_gemm_packa
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_5);
obj_t a_pack;
// By setting the family id to BLIS_GEMM_MD, we direct the packing
// kernels to scale alpha while packing.
if(bli_obj_dt(c) != bli_obj_dt(b))
bli_cntl_set_family(BLIS_GEMM_MD, cntl);
// Pack matrix A according to the control tree node.
bli_l3_packm
(