From a4876a8c679ef77e31112613af141d9a3270883f Mon Sep 17 00:00:00 2001 From: zjing14 Date: Mon, 12 Feb 2024 11:45:42 -0600 Subject: [PATCH] Optimizing fp8_fp16 mixedprec gemm (#1150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add delayed cvt * extend fp16 gemm_splitk instances for fp8_fp16 gemm * add f8 example * add 128 kperblk instances for fp8 * add kpb128 instance * added more instances into kpb128 * clean code * clean code * fix * fix * fixed * Update example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp Co-authored-by: Bartłomiej Kocot * Update include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp Co-authored-by: Bartłomiej Kocot * Update library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp Co-authored-by: Bartłomiej Kocot --------- Co-authored-by: Jing Zhang Co-authored-by: Bartłomiej Kocot [ROCm/composable_kernel commit: 602c4cc0d934e78bf2f2c07d95916398014bd767] --- example/35_splitK_gemm/CMakeLists.txt | 3 + .../splitK_gemm_xdl_fp16_fp8.cpp | 60 +++++++ .../gpu/block/blockwise_gemm_xdlops.hpp | 69 ++++---- .../impl/device_gemm_xdl_splitk_c_shuffle.hpp | 12 +- .../element/unary_element_wise_operation.hpp | 39 ----- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 25 +-- .../threadwise_tensor_slice_transfer.hpp | 61 ++++++-- .../gpu/gemm_splitk.hpp | 6 + .../gpu/gemm_splitk/CMakeLists.txt | 79 +++++----- ...k_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp | 147 ++++++++++++++++++ 10 files changed, 370 insertions(+), 131 deletions(-) create mode 100644 example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index f98308d687..5277b32f63 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -10,6 +10,9 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) + add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) + add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp new file mode 100644 index 0000000000..b93639e6c1 --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using AccDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 16, 64, 8, 16, 16, 16, 1, 2, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Default, ADataType, BDataType>; + +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 904a96cc9f..701dd04f6c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -37,7 +37,9 @@ template + index_t KPack, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = FloatB> struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { static constexpr auto I0 = Number<0>{}; @@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf); static_for<0, KPerThread, KPack>{}([&](auto k) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf + a_thread_vec.template AsType()(i) = a_thread_buf [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf + b_thread_vec.template AsType()(i) = b_thread_buf [Number{}]; }); using mfma_input_type_a = - typename vector_type::type; + typename vector_type::type; using mfma_input_type_b = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -398,6 +401,8 @@ template struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + KPack, + ComputeTypeA, + ComputeTypeB> { using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + KPack, + ComputeTypeA, + ComputeTypeB>; #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING using Base::a_block_desc_m0_m1_m2_k; @@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { @@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = + a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = + b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; }); using mfma_input_type_a = - typename vector_type::type; + typename vector_type::type; using mfma_input_type_b = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Number{}, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -586,7 +595,9 @@ template + LoopScheduler LoopSched, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = FloatB> constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() { if constexpr(LoopSched == LoopScheduler::Default) @@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() NPerXDL, MRepeat, NRepeat, - KPack>{}; + KPack, + ComputeTypeA, + ComputeTypeB>{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { @@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() NPerXDL, MRepeat, NRepeat, - KPack>{}; + KPack, + ComputeTypeA, + ComputeTypeB>{}; } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp index 86c025aa6f..7f28ec7680 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -60,7 +60,9 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename LDSTypeA = ComputeType, + typename LDSTypeB = ComputeType> struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; struct Argument : public GridwiseGemm::Argument { diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 70c72bf768..33c2cb6c6d 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -21,50 +21,11 @@ struct PassThroughPack2 template __host__ __device__ void operator()(Y& y, const X& x) const; - __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const - { - // fake conversion - uint16_t t = ck::bit_cast(x); - y = ck::bit_cast(t); - } - __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const { auto t = type_convert(x); y = type_convert(t); } - - __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const - { - y = x; - } - - __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const - { - y = x; - } - - __host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const - { - y = x; - } - - __host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const - { - y = x; - } - - __host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const - { - y = x; - } - - __host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const - { - y = x; - } - - constexpr const static bool is_pack2_invocable = true; }; struct PassThrough diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 6cbb834395..b52f5c51b1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -9,7 +9,6 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" @@ -96,7 +95,10 @@ template + typename ComputeTypeA = FloatC, + typename ComputeTypeB = ComputeTypeA, + typename LDSTypeA = ComputeTypeA, + typename LDSTypeB = ComputeTypeB> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { static constexpr auto I0 = Number<0>{}; @@ -430,7 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto c_block_size = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType), + return math::max(a_block_space_size * sizeof(LDSTypeA) + + b_block_space_size * sizeof(LDSTypeB), c_block_size * sizeof(FloatC)); } @@ -785,7 +788,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, FloatA, - ComputeType, + LDSTypeA, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -815,7 +818,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, FloatB, - ComputeType, + LDSTypeB, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -845,8 +848,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, // ComputeType A - ComputeType, // ComputeType B + LDSTypeA, + LDSTypeB, FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), @@ -855,7 +858,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 MRepeat, NRepeat, K1, - LoopSched>(); + LoopSched, + ComputeTypeA, + ComputeTypeB>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -863,8 +868,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto a_block_space_size = math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - ComputeType* p_a_block = static_cast(p_shared_block); - ComputeType* p_b_block = static_cast(p_shared_block) + a_block_space_size; + auto p_a_block = reinterpret_cast(p_shared_block); + auto p_b_block = reinterpret_cast(p_a_block + a_block_space_size); constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 2774214079..608679a4fa 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -8,6 +8,8 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + namespace ck { // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory @@ -1156,27 +1158,56 @@ struct ThreadwiseTensorSliceTransfer_v4 src_ref_to_origin_disp_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - // apply type convert src_tmp_vector.template AsType()(i) = src_buf[Number{}]; }); } - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - // TODO: if SrcData and DstData are vetor type, then static_cast may not compile - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - dst_tmp_vector.template AsType()(i) = - type_convert(src_tmp_vector.template AsType()[i]); - }); + if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; - // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + constexpr index_t pack_size = 2; - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); + using dst_v_t = typename vector_type_maker_t::type; + using src_v_t = typename vector_type_maker_t::type; + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack2{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } }); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp index ebbe7c7211..863eddef24 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -189,6 +189,11 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances( DeviceGemmSplitK>>& instances); +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances( + std::vector>>& + instances); + void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances( std::vector>>& @@ -352,6 +357,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances(op_ptrs); add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(op_ptrs); add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index a4d23914dd..059b6a720f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -1,42 +1,45 @@ set(GEMM_SPLITK_INSTANCES) -list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp - device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp - device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp) +list(APPEND GEMM_SPLITK_INSTANCES + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp + device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp + ) add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp new file mode 100644 index 0000000000..0409dec369 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 16, 16, 16, 1, 1, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 16, 16, 16, 1, 2, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 16, 16, 16, 1, 4, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances( + std::vector>>& + instances) +{ + // default + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmDefault, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + + // MNKPadding + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + + // KPadding + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + + // MNPadding + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck