From fbd9dde8699c2a2a9c2b580e19d233e394165b5b Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Fri, 6 Jan 2023 12:55:45 +0000 Subject: [PATCH] Fix bug in DZGEMM Details: - In case of dzgemm, if the microkernel prefers column output, We will induce a transposition and perform C += A*B where A (formerly B) becomes complex and B(formerly A) becomes real. Hence attach complex alpha object to A instead of B. - This commit reverts all the changes made by d62f12a18a9e8c56ac9a64253c266017e4502a8e and 0b81f530746f856f7cc8f7e5be9684b016d712db as they are causing failures in make checkblis-md. AMD-Internal: [CPUPL-2893] Change-Id: I56b94ac136fb96003302c568ae2587142c836620 --- frame/3/gemm/bli_gemm_front_amd.c | 11 ++- frame/3/gemm/bli_gemm_md.c | 111 +++++++++++++++++++++++++++--- frame/3/gemm/bli_gemm_packab.c | 7 +- 3 files changed, 118 insertions(+), 11 deletions(-) diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c index b15d906dd..b64baf000 100644 --- a/frame/3/gemm/bli_gemm_front_amd.c +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -132,7 +132,14 @@ void bli_gemm_front // Attach alpha to B, and in the process typecast alpha to the target // datatype of the matrix (which in this case is equal to the computation // datatype). - bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &b_local ); + + // In case of dzgemm, if the microkernel prefers column output, + // we will induce a transposition and perform C+= A*B + // where A( formerly B) is complex. Hence attach alpha to A. 
+ if ( bli_gemm_md_is_ccr( &a_local, &b_local, &c_local )) + bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &a_local ); + else + bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &b_local ); // Attach beta to C, and in the process typecast beta to the target // datatype of the matrix (which in this case is equal to the storage diff --git a/frame/3/gemm/bli_gemm_md.c b/frame/3/gemm/bli_gemm_md.c index 68298c71c..158069f63 100644 --- a/frame/3/gemm/bli_gemm_md.c +++ b/frame/3/gemm/bli_gemm_md.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -156,20 +156,92 @@ mddm_t bli_gemm_md_ccr cntx_t** cntx ) { + mddm_t doms; + + // We assume that the requested computation domain is complex. + //dom_t dom_comp_in = bli_obj_comp_domain( c ); + //dom_t dom_comp_in = BLIS_COMPLEX; + + // For ccr, the computation (ukernel) will be real, but the execution + // will appear complex to other parts of the implementation. + doms.comp = BLIS_REAL; + doms.exec = BLIS_COMPLEX; + + // Here we construct the computation datatype, which for the ccr case + // is equal to the real projection of the execution datatype, and use + // that computation datatype to query the corresponding ukernel output + // preference. + const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); + const bool row_pref + = bli_cntx_l3_nat_ukr_prefers_rows_dt( dt, BLIS_GEMM_UKR, *cntx ); + + // We can only perform this case of mixed-domain gemm, C += A*B where + // B is real, if the microkernel prefers column output. If it prefers + // row output, we must induce a transposition and perform C += A*B + // where A (formerly B) is real. 
+ if ( row_pref ) + { + bli_obj_swap( a, b ); + + bli_obj_induce_trans( a ); + bli_obj_induce_trans( b ); + bli_obj_induce_trans( c ); + + return bli_gemm_md_crc( a, b, beta, c, cntx_local, cntx ); + } + // Create a local copy of the context and then prepare to use this // context instead of the one passed in. *cntx_local = **cntx; *cntx = cntx_local; - //we must induce a transposition and perform C += A*B - // where A (formerly B) is real. - bli_obj_swap( a, b ); + // Copy the real domain blocksizes into the slots of their complex + // counterparts. + blksz_t* blksz_mr = bli_cntx_get_blksz( BLIS_MR, *cntx ); + blksz_t* blksz_nr = bli_cntx_get_blksz( BLIS_NR, *cntx ); + blksz_t* blksz_mc = bli_cntx_get_blksz( BLIS_MC, *cntx ); + blksz_t* blksz_nc = bli_cntx_get_blksz( BLIS_NC, *cntx ); + blksz_t* blksz_kc = bli_cntx_get_blksz( BLIS_KC, *cntx ); - bli_obj_induce_trans( a ); - bli_obj_induce_trans( b ); - bli_obj_induce_trans( c ); + bli_blksz_copy_dt( BLIS_FLOAT, blksz_mr, BLIS_SCOMPLEX, blksz_mr ); + bli_blksz_copy_dt( BLIS_DOUBLE, blksz_mr, BLIS_DCOMPLEX, blksz_mr ); - return bli_gemm_md_crc( a, b, beta, c, cntx_local, cntx ); + bli_blksz_copy_dt( BLIS_FLOAT, blksz_nr, BLIS_SCOMPLEX, blksz_nr ); + bli_blksz_copy_dt( BLIS_DOUBLE, blksz_nr, BLIS_DCOMPLEX, blksz_nr ); + + bli_blksz_copy_dt( BLIS_FLOAT, blksz_mc, BLIS_SCOMPLEX, blksz_mc ); + bli_blksz_copy_dt( BLIS_DOUBLE, blksz_mc, BLIS_DCOMPLEX, blksz_mc ); + + bli_blksz_copy_dt( BLIS_FLOAT, blksz_nc, BLIS_SCOMPLEX, blksz_nc ); + bli_blksz_copy_dt( BLIS_DOUBLE, blksz_nc, BLIS_DCOMPLEX, blksz_nc ); + + bli_blksz_copy_dt( BLIS_FLOAT, blksz_kc, BLIS_SCOMPLEX, blksz_kc ); + bli_blksz_copy_dt( BLIS_DOUBLE, blksz_kc, BLIS_DCOMPLEX, blksz_kc ); + + // Halve both the real and complex MR's (which are both real MR's). 
+ bli_blksz_scale_def_max( 1, 2, BLIS_FLOAT, blksz_mr ); + bli_blksz_scale_def_max( 1, 2, BLIS_DOUBLE, blksz_mr ); + bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_mr ); + bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_mr ); + + // Halve both the real and complex MC's (which are both real MC's). + bli_blksz_scale_def_max( 1, 2, BLIS_FLOAT, blksz_mc ); + bli_blksz_scale_def_max( 1, 2, BLIS_DOUBLE, blksz_mc ); + bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_mc ); + bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_mc ); + + // Use the default pack schemas in the context. + + // static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) + func_t* l3_vir_ukrs = bli_cntx_get_l3_vir_ukrs( BLIS_GEMM_UKR, *cntx ); + + // Rather than check which complex datatype dt_comp refers to, we set + // the mixed-domain virtual microkernel for both types. + bli_func_set_dt( bli_cgemm_md_c2r_ref, BLIS_SCOMPLEX, l3_vir_ukrs ); + bli_func_set_dt( bli_zgemm_md_c2r_ref, BLIS_DCOMPLEX, l3_vir_ukrs ); + + // Return the computation and execution domains. + return doms; } // ----------------------------------------------------------------------------- @@ -196,6 +268,29 @@ mddm_t bli_gemm_md_crc doms.comp = BLIS_REAL; doms.exec = BLIS_COMPLEX; + // Here we construct the computation datatype, which for the crc case + // is equal to the real projection of the execution datatype, and use + // that computation datatype to query the corresponding ukernel output + // preference. + const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); + const bool col_pref + = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, *cntx ); + + // We can only perform this case of mixed-domain gemm, C += A*B where + // A is real, if the microkernel prefers row output. If it prefers + // column output, we must induce a transposition and perform C += A*B + // where B (formerly A) is real. 
+ if ( col_pref ) + { + bli_obj_swap( a, b ); + + bli_obj_induce_trans( a ); + bli_obj_induce_trans( b ); + bli_obj_induce_trans( c ); + + return bli_gemm_md_ccr( a, b, beta, c, cntx_local, cntx ); + } + // Create a local copy of the context and then prepare to use this // context instead of the one passed in. *cntx_local = **cntx; diff --git a/frame/3/gemm/bli_gemm_packab.c b/frame/3/gemm/bli_gemm_packab.c index 682872554..098206df7 100644 --- a/frame/3/gemm/bli_gemm_packab.c +++ b/frame/3/gemm/bli_gemm_packab.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,6 +49,11 @@ void bli_gemm_packa AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_5); obj_t a_pack; + // By setting family id to BLIS_GEMM_MD, we indicate packing kernels + // to scale alpha while packing. + if(bli_obj_dt(c) != bli_obj_dt(b)) + bli_cntl_set_family(BLIS_GEMM_MD, cntl); + // Pack matrix A according to the control tree node. bli_l3_packm (