mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add support for mixed precision in contraction scale and bilinear (#973)
* Add support for mixed precision in contraction scale and bilinear (#936) * Extract common functionality to separate files * Reference contraction: Remove incorrect consts from type_converts * Reference contraction: Add missing type_convert for dst value * Reference contraction: Fix incorrect order of B matrix dimensions * Add support for mixed precision in contraction scale and bilinear * Move using statements from instances to a common file * Move using statements from examples to a common file * Fix the order of B matrix dimensions across examples and profiler * Fix the computation of error threshold * Make ComputeDataType an optional argument * Include possible DataType -> ComputeDataType casting error in the threshold * Remove commented code * Make the ComputeDataType an optional argument in instance --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
73743aa0aa
commit
4ef704d8a6
@@ -33,7 +33,8 @@ template <index_t NumDimM,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
typename CDEElementwiseOperation,
|
||||
typename ComputeDataType = ADataType>
|
||||
struct DeviceContractionMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
@@ -145,7 +145,8 @@ template <index_t NumDimM,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
typename ComputeDataType = ADataType,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
: public DeviceContractionMultipleD<NumDimM,
|
||||
NumDimN,
|
||||
@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
CDEElementwiseOperation,
|
||||
ComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
|
||||
|
||||
@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
|
||||
Reference in New Issue
Block a user