mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add support for mixed precision in contraction scale and bilinear (#936)
* Extract common functionality to separate files * Reference contraction: Remove incorrect consts from type_converts * Reference contraction: Add missing type_convert for dst value * Reference contraction: Fix incorrect order of B matrix dimensions * Add support for mixed precision in contraction scale and bilinear * Move using statements from instances to a common file * Move using statements from examples to a common file * Fix the order of B matrix dimensions across examples and profiler * Fix the computation of error threshold * Make ComputeDataType an optional argument * Include possible DataType -> ComputeDataType casting error in the threshold * Remove commented code
This commit is contained in:
committed by
GitHub
parent
cb53874002
commit
f07485060e
@@ -112,6 +112,7 @@ template <index_t NumDimM,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ComputeDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
CDEElementwiseOperation,
|
||||
ComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
|
||||
|
||||
@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
|
||||
Reference in New Issue
Block a user