From 253ceffb0fea667ddee34ba0b06e103a4a19e4d5 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Thu, 16 Mar 2023 11:53:16 +0000 Subject: [PATCH] Redirecting scalm to setm when alpha is zero Added a check in scalm framework for alpha=0. Set the output matrix to zero when alpha=0. This ensures that any Inf or NaN values in the matrix are not propagated to the output matrix. AMD-Internal: [CPUPL-3053] Change-Id: I62b9b5405be220eb4df97aadda14701abcccb475 --- frame/1m/bli_l1m_tapi.c | 66 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/frame/1m/bli_l1m_tapi.c b/frame/1m/bli_l1m_tapi.c index 2b3c4bb4a..bfcc38fd5 100644 --- a/frame/1m/bli_l1m_tapi.c +++ b/frame/1m/bli_l1m_tapi.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -336,6 +337,70 @@ void PASTEMAC2(ch,opname,EX_SUF) \ INSERT_GENTFUNC_BASIC0( scal2m ) +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +/* scalm entry point generated from this template by INSERT_GENTFUNC_BASIC0( scalm ) below. Returns at once when bli_zero_dim2( m, n ) holds; when *alpha tests equal to zero it calls the setm _unb_var1 helper on x, otherwise the opname (scalm) _unb_var1 helper, with the same argument list in both cases. */ \ +void PASTEMAC2(ch,opname,EX_SUF) \ + ( \ + conj_t conjalpha, \ + doff_t diagoffx, \ + diag_t diagx, \ + uplo_t uplox, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* x, inc_t rs_x, inc_t cs_x \ + BLIS_TAPI_EX_PARAMS \ + ) \ +{ \ + bli_init_once(); \ +\ + BLIS_TAPI_EX_DECLS \ +\ + if ( bli_zero_dim2( m, n ) ) return; \ +\ + /* Obtain a valid context from the gks if necessary, i.e. when cntx is NULL. */ \ + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ +\ + /* Invoke setm function if alpha is zero. Per the description of this change, the matrix x is then set to zero so that any Inf or NaN values already in x are not propagated; alpha (zero on this path) is the value handed on to setm. 
*/ \ + if ( PASTEMAC(ch,eq0)(*alpha)) \ + { \ + PASTEMAC2(ch,setm,_unb_var1) \ + ( \ + conjalpha, \ + diagoffx, \ + diagx, \ + uplox, \ + m, \ + n, \ + alpha, \ + x, rs_x, cs_x, \ + cntx, \ + rntm \ + ); \ + } \ + else \ + { \ + /* Invoke the helper variant, which loops over the appropriate kernel + to implement the current operation. */ \ + PASTEMAC2(ch,opname,_unb_var1) \ + ( \ + conjalpha, \ + diagoffx, \ + diagx, \ + uplox, \ + m, \ + n, \ + alpha, \ + x, rs_x, cs_x, \ + cntx, \ + rntm \ + ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0( scalm ) + #undef GENTFUNC #define GENTFUNC( ctype, ch, opname ) \ \ @@ -378,7 +443,6 @@ void PASTEMAC2(ch,opname,EX_SUF) \ ); \ } -INSERT_GENTFUNC_BASIC0( scalm ) INSERT_GENTFUNC_BASIC0( setm )