mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Enable f16/f8 mixed precision mode (#820)
* Enable f16/f8 mixed precision * Add an argument to enable mixed precision * Update for compatibility * Add mixed precision example * Introduce ComputeType argument
This commit is contained in:
@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, true>;
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
true>;
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, false>;
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
false>;
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
|
||||
@@ -65,7 +65,8 @@ template <typename ALayout,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename ComputeType = CDataType>
|
||||
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
PipelineVer,
|
||||
ComputeType>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user