Enable f16/f8 mixed precision mode (#820)

* Enable f16/f8 mixed precision

* Add an argument to enable mixed precision

* Update for compatibility

* Add mixed precision example

* Introduce ComputeType argument
This commit is contained in:
Rostyslav Geyyer
2023-08-09 08:44:23 -05:00
committed by GitHub
parent 6802611334
commit 9c54eaab04
5 changed files with 89 additions and 29 deletions

View File

@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
ALayout,
BLayout,
CLayout,
ADataType, // TODO: distinguish A/B datatype
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, true>;
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType,
BDataType,
CDataType,
true>;
ave_time += launch_and_time_kernel(stream_config,
kernel,
@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, false>;
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType,
BDataType,
CDataType,
false>;
ave_time += launch_and_time_kernel(stream_config,
kernel,

View File

@@ -65,7 +65,8 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = CDataType>
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BLayout,
CLayout,
@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
ALayout,
BLayout,
CLayout,
ADataType, // TODO: distinguish A/B datatype
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
PipelineVer,
ComputeType>;
using Argument = typename GridwiseGemm::Argument;