mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
Add support for mixed precision in contraction scale and bilinear (#973)
* Add support for mixed precision in contraction scale and bilinear (#936)
* Extract common functionality to separate files
* Reference contraction: Remove incorrect consts from type_converts
* Reference contraction: Add missing type_convert for dst value
* Reference contraction: Fix incorrect order of B matrix dimensions
* Add support for mixed precision in contraction scale and bilinear
* Move using statements from instances to a common file
* Move using statements from examples to a common file
* Fix the order of B matrix dimensions across examples and profiler
* Fix the computation of error threshold
* Make ComputeDataType an optional argument
* Include possible DataType -> ComputeDataType casting error in the threshold
* Remove commented code
* Make the ComputeDataType an optional argument in instance
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
[ROCm/composable_kernel commit: 4ef704d8a6]
This commit is contained in:
Committed by: GitHub
Parent: 3eaee9196f
Commit: 382a513acb
@@ -145,7 +145,8 @@ template <index_t NumDimM,
 index_t CShuffleNXdlPerWavePerShuffle,
 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 index_t CDEBlockTransferScalarPerVector_NPerBlock,
-LoopScheduler LoopSched = make_default_loop_scheduler()>
+typename ComputeDataType = ADataType,
+LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceContractionMultipleD_Xdl_CShuffle
 : public DeviceContractionMultipleD<NumDimM,
 NumDimN,
@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
 EDataType,
 AElementwiseOperation,
 BElementwiseOperation,
-CDEElementwiseOperation>
+CDEElementwiseOperation,
+ComputeDataType>
 {
 using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
 
@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
 using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
 
-using ComputeDataType = ADataType;
-
 // GridwiseGemm
 using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
 ADataType, // TODO: distinguish A/B datatype
Reference in New Issue
Block a user