mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Optimize bf16 conversion (#664)
* Add TypeConvert class and start refactoring * Refactor TypeConvert as a struct * Get back to template functions type_convert * Add a type_convert_bf16_rtn, set rtz as default * Clean up * Add UnaryConvertPrecision struct for high-precision workloads * Format * Update type_convert to UnaryConvert on threadwise level * Update UnaryConvertPrecision * Format * Fix chmod * Add a flag to pick converion method * Format * Remove the added flag * Merge elementwise op with type conversion * Move type_convert to elemwise op, update the op * Update type_convert_precision -> bf16_convert_rtn * Clean up * Update comments * Update the CK_WORKAROUND_DENORM_FIX flag handling * Update the unneeded op to work but warn user * Remove the message * Use a PassThrough instead of ConvertBF16RTN to calcaulate reference * Format * Add missing include
This commit is contained in:
@@ -96,7 +96,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// ABDataTypeAdjusted -> ABDataType throughout this file
|
||||
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using ABDataTypeAdjusted =
|
||||
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
|
||||
#else
|
||||
|
||||
@@ -266,7 +266,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// FloatABAdjusted -> FloatAB throughout this file
|
||||
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
|
||||
#else
|
||||
using FloatABAdjusted = FloatAB;
|
||||
|
||||
@@ -136,7 +136,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// FloatABAdjusted -> FloatAB throughout this file
|
||||
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
|
||||
#else
|
||||
using FloatABAdjusted = FloatAB;
|
||||
|
||||
Reference in New Issue
Block a user