From 8ef2986daeafdb5213793af572bccdbeb26749c3 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Wed, 9 Aug 2023 08:44:23 -0500 Subject: [PATCH] Enable f16/f8 mixed precision mode (#820) * Enable f16/f8 mixed precision * Add an argument to enable mixed precision * Update for compatibility * Add mixed precision example * Introduce ComputeType argument [ROCm/composable_kernel commit: 9c54eaab04e6db605dc86f1d1ab16bd04f51fc89] --- example/01_gemm/CMakeLists.txt | 3 ++ example/01_gemm/gemm_xdl_fp16_f8.cpp | 41 ++++++++++++++++ .../impl/device_cgemm_4gemm_xdl_cshuffle.hpp | 17 +++++-- .../device/impl/device_gemm_xdl_cshuffle.hpp | 9 ++-- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 48 +++++++++++-------- 5 files changed, 89 insertions(+), 29 deletions(-) create mode 100644 example/01_gemm/gemm_xdl_fp16_f8.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 7ce866931c..f8b886ebbf 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -64,3 +64,6 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) add_dependencies(example_gemm_xdl example_gemm_xdl_f8) endif() endif() + +add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) +add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) diff --git a/example/01_gemm/gemm_xdl_fp16_f8.cpp b/example/01_gemm/gemm_xdl_fp16_f8.cpp new file mode 100644 index 0000000000..d3cf3d397a --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_f8.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); +static constexpr auto PipelineVer = ck::PipelineVersion::v1; +using ComputeType = ck::half_t; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| ComputeType| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| | +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp index 106746716d..a095521161 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ALayout, BLayout, CLayout, - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, GemmAccDataType, CShuffleDataType, CDataType, @@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v1; + const auto kernel = kernel_gemm_xdl_cshuffle_v1; ave_time += launch_and_time_kernel(stream_config, kernel, @@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle } else { - const auto kernel = - kernel_gemm_xdl_cshuffle_v1; + const auto kernel = kernel_gemm_xdl_cshuffle_v1; ave_time += launch_and_time_kernel(stream_config, kernel, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp index 4c9ddbab80..86bc7b5bbb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp @@ -65,7 +65,8 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeType = CDataType> struct DeviceGemm_Xdl_CShuffle : public DeviceGemm; + PipelineVer, + ComputeType>; using Argument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 7f1fb21d33..29cbf14a57 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -35,13 +35,17 @@ __global__ void #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } -template +template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, typename GridwiseGemm::Problem problem) { @@ -61,7 +65,8 @@ __global__ void template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeType = FloatC> struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; @@ -463,8 +469,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // Argument struct Argument : public tensor_operation::device::BaseArgument, public Problem { - __host__ Argument(const FloatAB* p_a_grid_, - const FloatAB* p_b_grid_, + __host__ Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, FloatC* p_c_grid_, index_t M_, index_t N_, @@ -479,8 +485,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { } - const FloatAB* p_a_grid; - const FloatAB* p_b_grid; + const FloatA* p_a_grid; + const FloatB* p_b_grid; FloatC* p_c_grid; }; @@ -541,8 +547,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), + return math::max((a_block_space_size_aligned * sizeof(ComputeType) + + b_block_space_size_aligned * sizeof(ComputeType)), c_block_size * sizeof(FloatCShuffle)); } @@ -676,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, const Problem& problem) @@ -743,8 +749,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatA, + ComputeType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -774,8 +780,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 Sequence, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatB, + ComputeType, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -805,11 +811,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // sanity check constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - FloatAB, + ComputeType, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -827,10 +833,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);