From 54ae3063987a7ca5cdb0c56ae59534a9dbc92659 Mon Sep 17 00:00:00 2001 From: darren-amd Date: Tue, 28 Jan 2025 09:58:39 -0500 Subject: [PATCH] Change flag to CK_GFX90A_DENORM_WORKAROUND (#1817) * Change flag from CK_WORKAROUND_DENORM_FIX to CK_GFX90A_DENORM_WORKAROUND for more clarity. Also changed the definition macros to be more clear. [ROCm/composable_kernel commit: d6a4605e1c77d73f5387bb4e5735411dafd6fb83] --- include/ck/ck.hpp | 17 +++++++++++------ .../gridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 2 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 2 +- ..._multiple_d_xdl_cshuffle_lds_direct_load.hpp | 2 +- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 2 +- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 2 +- 6 files changed, 16 insertions(+), 11 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index d876f8fcb3..fc9d074716 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -235,13 +235,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // workaround: compiler issue on gfx908 #define CK_WORKAROUND_SWDEV_388832 1 -// denorm test fix, required to work around dissue -#ifndef CK_WORKAROUND_DENORM_FIX -#define CK_WORKAROUND_DENORM_FIX 0 +// denorm test fix, necessary for gfx90a +#ifndef CK_GFX90A_DENORM_WORKAROUND +#define CK_GFX90A_DENORM_WORKAROUND 0 +#endif // CK_GFX90A_DENORM_WORKAROUND +// Enable only for gfx90a +#if defined(__gfx90a__) +#if CK_GFX90A_DENORM_WORKAROUND +#define CK_GFX90A_DENORM_WORKAROUND 1 +#endif // CK_GFX90A_DENORM_WORKAROUND is set to 1 #else -// enable only for gfx90a -#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) -#endif // CK_WORKAROUND_DENORM_FIX +#define CK_GFX90A_DENORM_WORKAROUND 0 +#endif // gfx90a // set flag to 1 to build deprecated instances #define CK_BUILD_DEPRECATED 1 diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 60c02d64e1..150dd98064 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -101,7 +101,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; using BComputeDataType = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index e6085fad8c..4b344c02f8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -100,7 +100,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; using BComputeDataType = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index cd36b9e51a..b4c5d004c4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -164,7 +164,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; #else diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 5617f67f8b..b41e747a3a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -271,7 +271,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // when mfma if fixed, remove this section and update // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB, // throughout this file -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using FloatAAdjusted = conditional_t, ck::bhalf_t, ComputeTypeA>; using FloatBAdjusted = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 4f3caff248..5c3d9b7ba4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -254,7 +254,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update // FloatABAdjusted -> FloatAB throughout this file -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; #else using FloatABAdjusted = FloatAB;