From 19613902b58d402c883e033be37ba8a647bcb5a6 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 5 Sep 2021 12:41:28 -0500 Subject: [PATCH] GEMM driver and kernel (#29) * add gemm driver * tweak * add gemm kernel: mk_kn_mn and km_kn_mn * tweak * add GEMM km_nk_mn * fix comment --- host/driver_offline/CMakeLists.txt | 3 + .../include/device_gemm_xdlops_km_kn_mn.hpp | 219 +++++++++++++ .../include/device_gemm_xdlops_km_nk_mn.hpp | 219 +++++++++++++ .../include/device_gemm_xdlops_mk_kn_mn.hpp | 219 +++++++++++++ .../include/device_gemm_xdlops_mk_nk_mn.hpp | 275 ++++++++++++++++ .../src/gemm_driver_offline.cpp | 294 ++++++++++++++++++ host/host_tensor/include/gemm_common.hpp | 12 + host/host_tensor/include/host_gemm.hpp | 87 ++++++ script/run.sh | 23 +- 9 files changed, 1345 insertions(+), 6 deletions(-) create mode 100644 host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp create mode 100644 host/driver_offline/src/gemm_driver_offline.cpp create mode 100644 host/host_tensor/include/gemm_common.hpp create mode 100644 host/host_tensor/include/host_gemm.hpp diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt index 8dec70d03f..a3b3613293 100644 --- a/host/driver_offline/CMakeLists.txt +++ b/host/driver_offline/CMakeLists.txt @@ -14,11 +14,14 @@ include_directories(BEFORE set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp) +set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp) add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) 
add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE}) +add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE}) target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor) +target_link_libraries(gemm_driver_offline PRIVATE host_tensor) diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp new file mode 100644 index 0000000000..d9169649e6 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp @@ -0,0 +1,219 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, + const BDesc& b_k_n_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_k_m, + const Tensor& b_k_n, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + 
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_k_m_grid_desc.GetLength(I0); + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_k_m_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + 
const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_k_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + 
+ for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp new file mode 100644 index 0000000000..90e258d581 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp @@ -0,0 +1,219 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include 
"driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, + const BDesc& b_n_k_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_k_m, + const Tensor& b_n_k, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr 
index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_k_m_grid_desc.GetLength(I0); + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_n_k_grid_desc.GetLength(I0); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_k_m_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_n_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 
0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + 
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp new file mode 100644 index 0000000000..ab235d97e7 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp @@ -0,0 +1,219 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, + const BDesc& b_k_n_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + 
c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + 
const auto K = a_m_k_grid_desc.GetLength(I1); + const auto M = a_m_k_grid_desc.GetLength(I0); + const auto N = b_k_n_grid_desc.GetLength(I1); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_m_k_grid_desc, + make_tuple(make_pass_through_transform(M), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_k_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 
// 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " 
TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp new file mode 100644 index 0000000000..c68442d127 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp @@ -0,0 +1,275 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, + const BDesc& b_n_k_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_m_k, + const Tensor& b_n_k, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using 
BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 
64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_m_k_grid_desc.GetLength(I1); + const auto M = a_m_k_grid_desc.GetLength(I0); + const auto N = b_n_k_grid_desc.GetLength(I0); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_m_k_grid_desc, + make_tuple(make_pass_through_transform(M), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_n_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // HACK: hacks that control index calculation when 
iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + 
BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp new file mode 100644 index 0000000000..42c69ff6a2 --- /dev/null +++ b/host/driver_offline/src/gemm_driver_offline.cpp @@ -0,0 +1,294 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "gemm_common.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdlops_mk_kn_mn.hpp" +#include "device_gemm_xdlops_mk_nk_mn.hpp" +#include "device_gemm_xdlops_km_kn_mn.hpp" +#include 
"device_gemm_xdlops_km_nk_mn.hpp" + +#define USE_GEMM_XDL_MK_KN_MN 1 +#define USE_GEMM_XDL_MK_NK_MN 1 +#define USE_GEMM_XDL_KM_KN_MN 1 +#define USE_GEMM_XDL_KM_NK_MN 1 + +enum GemmAlgo +{ + Xdl_MK_KN_MN, // 0 + Xdl_MK_NK_MN, // 1 + Xdl_KM_KN_MN, // 2 + Xdl_KM_NK_MN, // 3 +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + // dynamic mode + if(argc != 10) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: M, N, K\n"); + exit(1); + } + + const auto layout = static_cast(std::stoi(argv[1])); + const auto algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + const index_t M = std::stoi(argv[7]); + const index_t N = std::stoi(argv[8]); + const index_t K = std::stoi(argv[9]); + +#if 0 + using ab_data_t = float; + using acc_data_t = float; + using c_data_t = float; +#elif 1 + using ab_data_t = half_t; + using acc_data_t = float; + using c_data_t = half_t; +#elif 1 + using ab_data_t = int8_t; + using acc_data_t = int32_t; + using c_data_t = int8_t; +#endif + + std::vector a_lengths_host(2), b_lengths_host(2), c_lengths_host(2); + std::vector a_strides_host(2), b_strides_host(2), c_strides_host(2); + + if(layout == GemmMatrixLayout::MK_KN_MN) + { + a_lengths_host[0] = static_cast(M); + a_lengths_host[1] = static_cast(K); + a_strides_host[0] = static_cast(K); + a_strides_host[1] = static_cast(1); + + b_lengths_host[0] = static_cast(K); + b_lengths_host[1] = static_cast(N); + b_strides_host[0] = static_cast(N); + b_strides_host[1] = static_cast(1); + + c_lengths_host[0] = static_cast(M); + c_lengths_host[1] = static_cast(N); + c_strides_host[0] = static_cast(N); + c_strides_host[1] = static_cast(1); + } + else if(layout == 
GemmMatrixLayout::MK_NK_MN)
+ {
+ a_lengths_host[0] = static_cast(M);
+ a_lengths_host[1] = static_cast(K);
+ a_strides_host[0] = static_cast(K);
+ a_strides_host[1] = static_cast(1);
+
+ b_lengths_host[0] = static_cast(N);
+ b_lengths_host[1] = static_cast(K);
+ b_strides_host[0] = static_cast(K);
+ b_strides_host[1] = static_cast(1);
+
+ c_lengths_host[0] = static_cast(M);
+ c_lengths_host[1] = static_cast(N);
+ c_strides_host[0] = static_cast(N);
+ c_strides_host[1] = static_cast(1);
+ }
+ else if(layout == GemmMatrixLayout::KM_KN_MN)
+ {
+ a_lengths_host[0] = static_cast(K);
+ a_lengths_host[1] = static_cast(M);
+ a_strides_host[0] = static_cast(M);
+ a_strides_host[1] = static_cast(1);
+
+ b_lengths_host[0] = static_cast(K);
+ b_lengths_host[1] = static_cast(N);
+ b_strides_host[0] = static_cast(N);
+ b_strides_host[1] = static_cast(1);
+
+ c_lengths_host[0] = static_cast(M);
+ c_lengths_host[1] = static_cast(N);
+ c_strides_host[0] = static_cast(N);
+ c_strides_host[1] = static_cast(1);
+ }
+ else if(layout == GemmMatrixLayout::KM_NK_MN)
+ {
+ a_lengths_host[0] = static_cast(K);
+ a_lengths_host[1] = static_cast(M);
+ a_strides_host[0] = static_cast(M);
+ a_strides_host[1] = static_cast(1);
+
+ b_lengths_host[0] = static_cast(N);
+ b_lengths_host[1] = static_cast(K);
+ b_strides_host[0] = static_cast(K);
+ b_strides_host[1] = static_cast(1);
+
+ c_lengths_host[0] = static_cast(M);
+ c_lengths_host[1] = static_cast(N);
+ c_strides_host[0] = static_cast(N);
+ c_strides_host[1] = static_cast(1);
+ }
+ else
+ {
+ throw std::runtime_error("wrong! 
not implemented"); + } + + Tensor a(a_lengths_host, a_strides_host); + Tensor b(b_lengths_host, b_strides_host); + Tensor c_host(c_lengths_host, c_strides_host); + Tensor c_device(c_lengths_host, c_strides_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(a.mDesc, std::cout << "a: "); + ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: "); + ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: "); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + auto f_make_for_device_mk_kn_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + + auto f_make_for_device_mk_nk_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return 
make_tuple(a_desc, b_desc, c_desc); + }; + + auto f_make_for_device_km_kn_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + + auto f_make_for_device_km_nk_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + +#if USE_GEMM_XDL_MK_KN_MN + if(algo == GemmAlgo::Xdl_MK_KN_MN) + { + if(layout != GemmMatrixLayout::MK_KN_MN) + { + throw std::runtime_error("wrong! layout"); + } + + const auto descs = f_make_for_device_mk_kn_mn(); + + device_gemm_xdlops_mk_kn_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_MK_NK_MN + if(algo == GemmAlgo::Xdl_MK_NK_MN) + { + if(layout != GemmMatrixLayout::MK_NK_MN) + { + throw std::runtime_error("wrong! layout"); + } + + const auto descs = f_make_for_device_mk_nk_mn(); + + device_gemm_xdlops_mk_nk_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_KN_MN + if(algo == GemmAlgo::Xdl_KM_KN_MN) + { + if(layout != GemmMatrixLayout::KM_KN_MN) + { + throw std::runtime_error("wrong! layout"); + } + + const auto descs = f_make_for_device_km_kn_mn(); + + device_gemm_xdlops_km_kn_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_NK_MN + if(algo == GemmAlgo::Xdl_KM_NK_MN) + { + if(layout != GemmMatrixLayout::KM_NK_MN) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto descs = f_make_for_device_km_nk_mn(); + + device_gemm_xdlops_km_nk_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + + if(do_verification) + { + host_gemm(a, b, c_host, layout); + + check_error(c_host, c_device); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_device.mData, ",") << std::endl; + } + } +} diff --git a/host/host_tensor/include/gemm_common.hpp b/host/host_tensor/include/gemm_common.hpp new file mode 100644 index 0000000000..f0f35a78b9 --- /dev/null +++ b/host/host_tensor/include/gemm_common.hpp @@ -0,0 +1,12 @@ +#ifndef GEMM_COMMON_HPP +#define GEMM_COMMON_HPP + +enum GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +#endif diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp new file mode 100644 index 0000000000..97cf245054 --- /dev/null +++ b/host/host_tensor/include/host_gemm.hpp @@ -0,0 +1,87 @@ +#pragma once +#include "host_tensor.hpp" +#include "gemm_common.hpp" + +template +void host_gemm(const Tensor& a, + const Tensor& b, + Tensor& c, + const GemmMatrixLayout layout) +{ + if(layout == GemmMatrixLayout::MK_KN_MN) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(k, n)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_MN) + { + auto f_mk_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * 
static_cast(b(n, k)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_MN) + { + auto f_km_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(k, n)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_NK_MN) + { + auto f_km_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(n, k)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! 
not supported layout"); + } +} diff --git a/script/run.sh b/script/run.sh index ecb5c85d81..3b383fcf3a 100755 --- a/script/run.sh +++ b/script/run.sh @@ -12,13 +12,16 @@ #export OLC_DEBUG_HIP_DUMP=1 #export OLC_DEBUG_SAVE_TEMP_DIR=1 - make -j conv_fwd_driver_offline - make -j conv_bwd_driver_offline - make -j conv_fwd_driver_online - #rm -rf /root/_hip_binary_kernels_/ #rm -rf /tmp/olCompile* +#make -j conv_fwd_driver_offline +#make -j conv_bwd_driver_offline +#make -j conv_wrw_driver_offline +#make -j conv_fwd_driver_online + + make -j gemm_driver_offline + LAYOUT=$1 ALGO=$2 VERIFY=$3 @@ -30,7 +33,7 @@ REPEAT=$6 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 - ./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 @@ -44,4 +47,12 @@ REPEAT=$6 #./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#./host/driver_offline/conv_wrw_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 + +#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 
128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 + +################################################ layout algo verify init log repeat M___ N___ K___ +#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 +#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 + ./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 +#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192