mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Fp16/fp8 mixed-precision Gemm with multiply+add fusion (#865)
* add compute_type * add multiply_add ckProfiler * add f8_fp16 support * clean * clean * fixed lds size calc * format --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -543,9 +543,13 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
EGridDesc_G_M_N e_grid_desc_g_m_n_;
|
||||
};
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
|
||||
@@ -331,8 +331,13 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
|
||||
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
|
||||
};
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<>, // DsDataType,
|
||||
|
||||
@@ -324,8 +324,12 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
index_t BatchStrideE_;
|
||||
};
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
|
||||
@@ -310,9 +310,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
|
||||
@@ -20,7 +20,8 @@
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename ABDataType,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsPointer,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
@@ -36,8 +37,8 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
@@ -242,9 +243,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
|
||||
|
||||
using ComputeDataType = EDataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -442,6 +447,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
BDataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
|
||||
@@ -355,9 +355,13 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
|
||||
@@ -355,6 +355,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
|
||||
@@ -367,9 +367,13 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
|
||||
@@ -228,9 +228,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
|
||||
Reference in New Issue
Block a user