diff --git a/testsuite/src/test_gemmt.c b/testsuite/src/test_gemmt.c index 3b7b08748..af61eff6e 100644 --- a/testsuite/src/test_gemmt.c +++ b/testsuite/src/test_gemmt.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,303 +36,344 @@ #include "blis.h" #include "test_libblis.h" +#define PRINT 0 // Static variables. -static char* op_str = "gemmt"; -static char* o_types = "mmm"; // a b c -static char* p_types = "uhh"; // uploc transa transb -static thresh_t thresh[BLIS_NUM_FP_TYPES] = { { 1e-04, 1e-05 }, // warn, pass for s - { 1e-04, 1e-05 }, // warn, pass for c - { 1e-13, 1e-14 }, // warn, pass for d - { 1e-13, 1e-14 } }; // warn, pass for z +static char *op_str = "gemmt"; +static char *o_types = "mmm"; // a b c +static char *p_types = "uhh"; // uploc transa transb +static thresh_t thresh[BLIS_NUM_FP_TYPES] = {{1e-04, 1e-05}, // warn, pass for s + {1e-04, 1e-05}, // warn, pass for c + {1e-13, 1e-14}, // warn, pass for d + {1e-13, 1e-14}}; // warn, pass for z // Local prototypes. 
-void libblis_test_gemmt_deps - ( - thread_data_t* tdata, - test_params_t* params, - test_op_t* op - ); +void libblis_test_gemmt_deps( + thread_data_t *tdata, + test_params_t *params, + test_op_t *op); -void libblis_test_gemmt_experiment - ( - test_params_t* params, - test_op_t* op, - iface_t iface, - char* dc_str, - char* pc_str, - char* sc_str, - unsigned int p_cur, - double* perf, - double* resid - ); +void libblis_test_gemmt_experiment( + test_params_t *params, + test_op_t *op, + iface_t iface, + char *dc_str, + char *pc_str, + char *sc_str, + unsigned int p_cur, + double *perf, + double *resid); -void libblis_test_gemmt_impl - ( - iface_t iface, - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c - ); +void libblis_test_gemmt_impl( + iface_t iface, + obj_t *alpha, + obj_t *a, + obj_t *b, + obj_t *beta, + obj_t *c); -void libblis_test_gemmt_check - ( - test_params_t* params, - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - obj_t* c_orig, - double* resid - ); +void libblis_test_gemmt_check( + test_params_t *params, + obj_t *alpha, + obj_t *a, + obj_t *b, + obj_t *beta, + obj_t *c, + obj_t *c_orig, + double *resid); - - -void libblis_test_gemmt_deps - ( - thread_data_t* tdata, - test_params_t* params, - test_op_t* op - ) +void libblis_test_gemmt_deps( + thread_data_t *tdata, + test_params_t *params, + test_op_t *op) { - libblis_test_randv( tdata, params, &(op->ops->randv) ); - libblis_test_randm( tdata, params, &(op->ops->randm) ); - libblis_test_setv( tdata, params, &(op->ops->setv) ); - libblis_test_normfv( tdata, params, &(op->ops->normfv) ); - libblis_test_subv( tdata, params, &(op->ops->subv) ); - libblis_test_scalv( tdata, params, &(op->ops->scalv) ); - libblis_test_copym( tdata, params, &(op->ops->copym) ); - libblis_test_scalm( tdata, params, &(op->ops->scalm) ); - libblis_test_gemv( tdata, params, &(op->ops->gemv) ); - libblis_test_gemm( tdata, params, &(op->ops->gemm) ); + libblis_test_randv(tdata, params, 
&(op->ops->randv)); + libblis_test_randm(tdata, params, &(op->ops->randm)); + libblis_test_normfv(tdata, params, &(op->ops->normfv)); + libblis_test_subv(tdata, params, &(op->ops->subv)); + libblis_test_copym(tdata, params, &(op->ops->copym)); + libblis_test_gemv(tdata, params, &(op->ops->gemv)); + libblis_test_addm(tdata, params, &(op->ops->addm)); } - - -void libblis_test_gemmt - ( - thread_data_t* tdata, - test_params_t* params, - test_op_t* op - ) +void libblis_test_gemmt( + thread_data_t *tdata, + test_params_t *params, + test_op_t *op) { // Return early if this test has already been done. - if ( libblis_test_op_is_done( op ) ) return; + if (libblis_test_op_is_done(op)) + return; // Return early if operation is disabled. - if ( libblis_test_op_is_disabled( op ) || - libblis_test_l3_is_disabled( op ) ) return; + if (libblis_test_op_is_disabled(op) || + libblis_test_l3_is_disabled(op)) + return; // Call dependencies first. - if ( TRUE ) libblis_test_gemmt_deps( tdata, params, op ); + if (TRUE) + libblis_test_gemmt_deps(tdata, params, op); // Execute the test driver for each implementation requested. 
//if ( op->front_seq == ENABLE ) { - libblis_test_op_driver( tdata, - params, - op, - BLIS_TEST_SEQ_FRONT_END, - op_str, - p_types, - o_types, - thresh, - libblis_test_gemmt_experiment ); + libblis_test_op_driver(tdata, + params, + op, + BLIS_TEST_SEQ_FRONT_END, + op_str, + p_types, + o_types, + thresh, + libblis_test_gemmt_experiment); } } - - -void libblis_test_gemmt_experiment - ( - test_params_t* params, - test_op_t* op, - iface_t iface, - char* dc_str, - char* pc_str, - char* sc_str, - unsigned int p_cur, - double* perf, - double* resid - ) +void libblis_test_gemmt_experiment( + test_params_t *params, + test_op_t *op, + iface_t iface, + char *dc_str, + char *pc_str, + char *sc_str, + unsigned int p_cur, + double *perf, + double *resid) { unsigned int n_repeats = params->n_repeats; unsigned int i; - double time_min = DBL_MAX; - double time; + double time_min = DBL_MAX; + double time; - num_t datatype; + num_t datatype; - dim_t m, k; + dim_t m, k; - uplo_t uploc; - trans_t transa; - trans_t transb; - - obj_t alpha, a, b, beta, c; - obj_t c_save; + uplo_t uploc; + trans_t transa, transb; + obj_t alpha, a, b, beta; + obj_t c, c_ref, c_org_tri, c_result_tri, c_save; // Use the datatype of the first char in the datatype combination string. - bli_param_map_char_to_blis_dt( dc_str[0], &datatype ); + bli_param_map_char_to_blis_dt(dc_str[0], &datatype); // Map the dimension specifier to actual dimensions. - m = libblis_test_get_dim_from_prob_size( op->dim_spec[0], p_cur ); - k = libblis_test_get_dim_from_prob_size( op->dim_spec[1], p_cur ); + m = libblis_test_get_dim_from_prob_size(op->dim_spec[0], p_cur); + k = libblis_test_get_dim_from_prob_size(op->dim_spec[1], p_cur); // Map parameter characters to BLIS constants. 
- bli_param_map_char_to_blis_uplo( pc_str[0], &uploc ); - bli_param_map_char_to_blis_trans( pc_str[1], &transa ); - bli_param_map_char_to_blis_trans( pc_str[2], &transb ); + bli_param_map_char_to_blis_uplo(pc_str[0], &uploc); + bli_param_map_char_to_blis_trans(pc_str[1], &transa); + bli_param_map_char_to_blis_trans(pc_str[2], &transb); // Create test scalars. - bli_obj_scalar_init_detached( datatype, &alpha ); - bli_obj_scalar_init_detached( datatype, &beta ); + bli_obj_scalar_init_detached(datatype, &alpha); + bli_obj_scalar_init_detached(datatype, &beta); // Create test operands (vectors and/or matrices). - libblis_test_mobj_create( params, datatype, transa, - sc_str[1], m, k, &a ); - libblis_test_mobj_create( params, datatype, transb, - sc_str[2], k, m, &b ); - libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[0], m, m, &c ); - libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[0], m, m, &c_save ); + libblis_test_mobj_create(params, datatype, transa, + sc_str[1], m, k, &a); + libblis_test_mobj_create(params, datatype, transb, + sc_str[2], k, m, &b); + libblis_test_mobj_create(params, datatype, BLIS_NO_TRANSPOSE, + sc_str[0], m, m, &c); + libblis_test_mobj_create(params, datatype, BLIS_NO_TRANSPOSE, + sc_str[0], m, m, &c_save); + libblis_test_mobj_create(params, datatype, BLIS_NO_TRANSPOSE, + sc_str[0], m, m, &c_ref); + libblis_test_mobj_create(params, datatype, BLIS_NO_TRANSPOSE, + sc_str[0], m, m, &c_org_tri); + libblis_test_mobj_create(params, datatype, BLIS_NO_TRANSPOSE, + sc_str[0], m, m, &c_result_tri); // Set alpha and beta. - if ( bli_obj_is_real( &c ) ) + if (bli_obj_is_real(&c)) { - bli_setsc( 1.2, 0.0, &alpha ); - bli_setsc( 0.9, 0.0, &beta ); + bli_setsc(1.2, 0.0, &alpha); + bli_setsc(-1.0, 0.0, &beta); } else { - bli_setsc( 1.2, 0.8, &alpha ); - bli_setsc( 0.9, 1.0, &beta ); + // For gemmt, both alpha and beta may be complex since, unlike herk, + // C is symmetric in both the real and complex cases. 
+ bli_setsc(1.2, 0.5, &alpha); + bli_setsc(-1.0, 0.5, &beta); } - // Randomize A and B. - libblis_test_mobj_randomize( params, TRUE, &a ); - libblis_test_mobj_randomize( params, TRUE, &b ); + // Randomize A and B + libblis_test_mobj_randomize(params, TRUE, &a); + libblis_test_mobj_randomize(params, TRUE, &b); -//bli_setm( &BLIS_ONE, &a ); -//bli_setm( &BLIS_ONE, &b ); -//bli_setsc( 1.0, 0.0, &alpha ); -//bli_setsc( 0.0, 0.0, &beta ); + // Apply the remaining parameters. + // We need to do this before we create the reference matrix + bli_obj_set_conjtrans(transa, &a); + bli_obj_set_conjtrans(transb, &b); - // Set the uplo property of C. - bli_obj_set_uplo( uploc, &c ); + // We want to create two final matrices + // 1. Input matrix c : This will be the random matrix used as input for gemmt + // it needs uplo settings for gemmt to decide which half to be updated + // for the result. + // 2. Reference matrix C_ref: This matrix is expected output from gemmt + // This matrix is constructed as explained below. + // + // a. c_org_tri: This matrix contains only the original elements from c + // which are not updated by GEMMT operation. All other elements will be set to 0. + // This is constructed by setting the uplo to the requested uplo + // and zeroing that triangle of a copy of c with setm. + // + // b. c_result_tri: This matrix contains only the elements that will be updated by gemmt + // This matrix is constructed by doing normal GEMM operation and converting the result + // to triangular matrix, this will ensure that all other elements except the required + // uploc settings are set to 0. + // + // c. Finally c_ref matrix is constructed by adding the above two matrices. + // + // 3. GEMMT operation will be performed using a, b & c and the results will be compared + // with c_ref. - // Randomize C, make it densely symmetric, and zero the unstored triangle - // to ensure the implementation reads only from the stored region. 
- libblis_test_mobj_randomize( params, TRUE, &c ); - bli_mksymm( &c ); - bli_mktrim( &c ); + // + // Assuming that gemmt is done on lower triangle we can represent + // this calculation as: + // + // gemmt(a,b,c) = L(gemm(a,b,c)) \ U(c) + // = C_result_tri \ C_org_tri + // (beta * C + alpha * A * B) \ C = ((beta * C + alpha * A * B) \ 0) \ (0\C) + // + // C_result_tri = lower triangle + // C_org_tri = strictly upper triangle. + // "\" represents matrix divided into triangles. + // + // For upper triangle operations the order of lower and upper matrices in + // these equations will be exchanged. - // Save C and set its uplo property. - bli_setm( &BLIS_ZERO, &c_save ); - bli_obj_set_uplo( uploc, &c_save ); - bli_copym( &c, &c_save ); + // Generate random input matrix + libblis_test_mobj_randomize(params, TRUE, &c); - // Apply the parameters. - bli_obj_set_conjtrans( transa, &a ); - bli_obj_set_conjtrans( transb, &b ); + // Create the required copies before setting the uplo attribute + bli_copym(&c, &c_save); + bli_copym(&c, &c_org_tri); + bli_copym(&c, &c_result_tri); + bli_obj_set_uplo(uploc, &c); + bli_obj_set_uplo(uploc, &c_save); - // Repeat the experiment n_repeats times and record results. - for ( i = 0; i < n_repeats; ++i ) + // Create c_org_tri matrix using setm operation, this matrix will + // have original values from input matrix "c" for all elements outside + // triangle selected for GEMMT operation. + bli_obj_set_uplo(uploc, &c_org_tri); // Select the requested uplo so that setm zeroes all elements in that triangle + bli_setm(&BLIS_ZERO, &c_org_tri); + bli_obj_toggle_uplo(&c_org_tri); // Toggle uplo now so that untouched triangle is active. + + // GEMMT output is same as GEMM for the triangle selected by uplo + // So we want to extract this triangle from complete GEMM results + // We do this by setting the uplo and converting the results + // to triangular matrix. 
+ // Perform gemm operation on original inputs + bli_gemm(&alpha, &a, &b, &beta, &c_result_tri); + // Set the values in other triangle to zero by converting it to triangular matrix + bli_obj_set_uplo(uploc, &c_result_tri); + bli_mktrim(&c_result_tri); + + // Now we have two matrices with opposite triangles set to zero + // c_result_tri: It has output of GEMM in selected triangle (including diagonal) + // Rest of its elements are set to zero. + // c_org_tri: It has values from original C matrix in the non-selected triangle + // Rest of the elements including diagonal are set to zero + // The result of the GEMMT operation will be the combination of these two matrices + // So add them together + bli_setm(&BLIS_ZERO, &c_ref); // Both matrices we are going to add, have uplo settings + // Clear the destination matrix to avoid partial updates + bli_copym(&c_org_tri, &c_ref); + bli_addm(&c_result_tri, &c_ref); + +#if PRINT + bli_printm("c", &c, "%5.2f", ""); + bli_printm("c_org_tri", &c_org_tri, "%5.2f", ""); + bli_printm("c_result_tri", &c_result_tri, "%5.2f", ""); + bli_printm("c_ref", &c_ref, "%5.2f", ""); +#endif + + // Repeat the experiment n_repeats times and record results. + for (i = 0; i < n_repeats; ++i) { - bli_copym( &c_save, &c ); + bli_copym(&c_save, &c); time = bli_clock(); - libblis_test_gemmt_impl( iface, &alpha, &a, &b, &beta, &c ); + libblis_test_gemmt_impl(iface, &alpha, &a, &b, &beta, &c); - time_min = bli_clock_min_diff( time_min, time ); + time_min = bli_clock_min_diff(time_min, time); } // Estimate the performance of the best experiment repeat. - *perf = ( 1.0 * m * m * k ) / time_min / FLOPS_PER_UNIT_PERF; - if ( bli_obj_is_complex( &c ) ) *perf *= 4.0; + *perf = (1.0 * m * m * k) / time_min / FLOPS_PER_UNIT_PERF; + if (bli_obj_is_complex(&c)) + *perf *= 4.0; // Perform checks. 
- libblis_test_gemmt_check( params, &alpha, &a, &b, &beta, &c, &c_save, resid ); + libblis_test_gemmt_check(params, &alpha, &a, &b, &beta, &c, &c_ref, resid); // Zero out performance and residual if output matrix is empty. - libblis_test_check_empty_problem( &c, perf, resid ); + libblis_test_check_empty_problem(&c, perf, resid); // Free the test objects. - bli_obj_free( &a ); - bli_obj_free( &b ); - bli_obj_free( &c ); - bli_obj_free( &c_save ); + bli_obj_free(&a); + bli_obj_free(&b); + bli_obj_free(&c); + bli_obj_free(&c_ref); + bli_obj_free(&c_org_tri); + bli_obj_free(&c_result_tri); + bli_obj_free(&c_save); } - - -void libblis_test_gemmt_impl - ( - iface_t iface, - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c - ) +void libblis_test_gemmt_impl( + iface_t iface, + obj_t *alpha, + obj_t *a, + obj_t *b, + obj_t *beta, + obj_t *c) { - switch ( iface ) + switch (iface) { - case BLIS_TEST_SEQ_FRONT_END: -#if 0 -//bli_printm( "alpha", alpha, "%5.2f", "" ); -//bli_printm( "beta", beta, "%5.2f", "" ); -bli_printm( "a", a, "%5.2f", "" ); -bli_printm( "b", b, "%5.2f", "" ); -bli_printm( "c", c, "%5.2f", "" ); + case BLIS_TEST_SEQ_FRONT_END: +#if PRINT + bli_printm("a", a, "%5.2f", ""); + bli_printm("b", b, "%5.2f", ""); + bli_printm("c Before", c, "%5.2f", ""); #endif -//if ( bli_obj_length( b ) == 16 && -// bli_obj_stor3_from_strides( c, a, b ) == BLIS_CRR ) -//bli_printm( "c before", c, "%6.3f", "" ); - bli_gemmt( alpha, a, b, beta, c ); -#if 0 -//if ( bli_obj_length( c ) == 12 && -// bli_obj_stor3_from_strides( c, a, b ) == BLIS_RRR ) -bli_printm( "c after", c, "%5.2f", "" ); + + bli_gemmt(alpha, a, b, beta, c); + +#if PRINT + bli_printm("c after", c, "%5.2f", ""); #endif break; - default: - libblis_test_printf_error( "Invalid interface type.\n" ); + default: + libblis_test_printf_error("Invalid interface type.\n"); } } - - -void libblis_test_gemmt_check - ( - test_params_t* params, - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - 
obj_t* c_orig, - double* resid - ) +void libblis_test_gemmt_check( + test_params_t *params, + obj_t *alpha, + obj_t *a, + obj_t *b, + obj_t *beta, + obj_t *c, + obj_t *c_orig, + double *resid) { - num_t dt = bli_obj_dt( c ); - num_t dt_real = bli_obj_dt_proj_to_real( c ); - uplo_t uploc = bli_obj_uplo( c ); + num_t dt = bli_obj_dt(c); + num_t dt_real = bli_obj_dt_proj_to_real(c); - dim_t m = bli_obj_length( c ); - //dim_t k = bli_obj_width_after_trans( a ); + dim_t m = bli_obj_length(c); - obj_t norm; - obj_t t, v, q, z; + obj_t norm; + obj_t t, v, z; double junk; @@ -340,14 +381,15 @@ void libblis_test_gemmt_check // Pre-conditions: // - a is randomized. // - b is randomized. - // - c_orig is randomized. + // - c is randomized with uplo set + // // Note: // - alpha and beta should have non-zero imaginary components in the // complex cases in order to more fully exercise the implementation. // // Under these conditions, we assume that the implementation for // - // C := beta * C_orig + alpha * transa(A) * transb(B) + // C := beta * C_orig + alpha * transa(A) * transb(B) // // is functioning correctly if // @@ -356,43 +398,36 @@ void libblis_test_gemmt_check // is negligible, where // // v = C * t - // z = ( beta * C_orig + alpha * transa(A) * transb(B) ) * t - // = beta * C_orig * t + alpha * transa(A) * transb(B) * t - // = beta * C_orig * t + alpha * uplo(Q) * t - // = beta * C_orig * t + z + // z = C_reference * t + // // - bli_obj_scalar_init_detached( dt_real, &norm ); + bli_obj_scalar_init_detached(dt_real, &norm); - bli_obj_create( dt, m, 1, 0, 0, &t ); - bli_obj_create( dt, m, 1, 0, 0, &v ); - bli_obj_create( dt, m, 1, 0, 0, &z ); + bli_obj_create(dt, m, 1, 0, 0, &t); + bli_obj_create(dt, m, 1, 0, 0, &v); + bli_obj_create(dt, m, 1, 0, 0, &z); - bli_obj_create( dt, m, m, 0, 0, &q ); - bli_obj_set_uplo( uploc, &q ); + libblis_test_vobj_randomize(params, TRUE, &t); - libblis_test_vobj_randomize( params, TRUE, &t ); + // Ensure result matrix has only selected 
triangle. + // Calculate V = C * t + bli_gemv(&BLIS_ONE, c, &t, &BLIS_ZERO, &v); + bli_gemv(&BLIS_ONE, c_orig, &t, &BLIS_ZERO, &z); - bli_gemv( &BLIS_ONE, c, &t, &BLIS_ZERO, &v ); - - bli_gemm( &BLIS_ONE, a, b, &BLIS_ZERO, &q ); -#if 1 - bli_mktrim( &q ); - bli_gemv( alpha, &q, &t, &BLIS_ZERO, &z ); -#else - bli_obj_set_struc( BLIS_TRIANGULAR, &q ); - bli_copyv( &t, &z ); - bli_trmv( alpha, &q, &z ); +#if PRINT + bli_printm("c-gemmt", c, "%5.2f", ""); + bli_printm("c-gemm", c_orig, "%5.2f", ""); + bli_printv("v", &v, "%5.2f", ""); + bli_printv("z", &z, "%5.2f", ""); #endif - bli_gemv( beta, c_orig, &t, &BLIS_ONE, &z ); - bli_subv( &z, &v ); - bli_normfv( &v, &norm ); - bli_getsc( &norm, resid, &junk ); + // Find the norm + bli_subv(&z, &v); + bli_normfv(&v, &norm); + bli_getsc(&norm, resid, &junk); - bli_obj_free( &t ); - bli_obj_free( &v ); - bli_obj_free( &z ); - bli_obj_free( &q ); + bli_obj_free(&t); + bli_obj_free(&v); + bli_obj_free(&z); } -