mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
GEMM driver and kernel (#29)
* add gemm driver * tweak * add gemm kernel: mk_kn_mn and km_kn_mn * tweak * add GEMM km_nk_mn * fix comment
This commit is contained in:
@@ -14,11 +14,14 @@ include_directories(BEFORE
|
||||
# Offline driver sources: one standalone benchmark/test executable per operator.
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp)
set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp)

add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE})
add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE})

# All drivers depend only on the host_tensor helper library.
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor)
target_link_libraries(gemm_driver_offline PRIVATE host_tensor)
219
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
Normal file
219
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
Normal file
@@ -0,0 +1,219 @@
|
||||
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"

/// Host-side launcher for an XDL (matrix-core) GEMM with A stored as [K, M],
/// B as [K, N] and C as [M, N]:  C(m, n) = sum_k A(k, m) * B(k, n).
///
/// Copies the host tensors to device memory, unmerges the K dimension into
/// [K0, K1] (K1 is the per-thread vector length along K), runs
/// driver_gemm_xdlops_v2r3 five times (each run internally averaged over
/// `nrepeat` kernel launches), prints the timing and achieved TFlop/s of each
/// run, and copies the result back into c_m_n.
///
/// @param a_k_m_grid_desc tensor descriptor of A, dimension order [K, M]
/// @param b_k_n_grid_desc tensor descriptor of B, dimension order [K, N]
/// @param c_m_n_grid_desc tensor descriptor of C, dimension order [M, N]
/// @param a_k_m           host input tensor A
/// @param b_k_n           host input tensor B
/// @param c_m_n           host output tensor C (overwritten with the result)
/// @param nrepeat         number of kernel launches averaged per timed run
template <typename ABType,
          typename AccType,
          typename CType,
          typename ADesc,
          typename BDesc,
          typename CDesc>
void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
                                 const BDesc& b_k_n_grid_desc,
                                 const CDesc& c_m_n_grid_desc,
                                 const Tensor<ABType>& a_k_m,
                                 const Tensor<ABType>& b_k_n,
                                 Tensor<CType>& c_m_n,
                                 ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
    DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());

    a_k_m_device_buf.ToDevice(a_k_m.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
    c_m_n_device_buf.ToDevice(c_m_n.mData.data());

    // Compile-time tuning parameters; exactly one preprocessor branch is active.
#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1      = 4;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 4, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_M  = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_N  = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;

    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1      = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_M  = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_N  = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif

    const auto K = a_k_m_grid_desc.GetLength(I0);
    const auto M = a_k_m_grid_desc.GetLength(I1);
    const auto N = b_k_n_grid_desc.GetLength(I1);

    constexpr auto K1Number = Number<K1>{};
    const auto K0           = K / K1Number;

    // Reshape A: [K, M] -> [K0, M, K1] (K unmerged into K0 x K1).
    const auto a_k0_m_k1_grid_desc =
        transform_tensor_descriptor(a_k_m_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                                               make_pass_through_transform(M)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // Reshape B: [K, N] -> [K0, N, K1].
    const auto b_k0_n_k1_grid_desc =
        transform_tensor_descriptor(b_k_n_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                                               make_pass_through_transform(N)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto a_k0_m_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
                              Sequence<0, 0, 0>{},   // 1+: M
                              Sequence<0, 0, 0>{}),  // 2+: K1
                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
                              Sequence<0, 0, 0>{},   // 1-: M
                              Sequence<0, 0, 0>{})); // 2-: K1

    constexpr auto b_k0_n_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
                              Sequence<0, 0, 0>{},   // 1+: N
                              Sequence<0, 0, 0>{}),  // 2+: K1
                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
                              Sequence<0, 0, 0>{},   // 1-: N
                              Sequence<0, 0, 0>{})); // 2-: K1

    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};

    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};

    // Run the timed benchmark 5 times to expose run-to-run variance.
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops_v2r3<BlockSize,
                                    ABType,
                                    AccType,
                                    CType,
                                    InMemoryDataOperationEnum_t::Set,
                                    decltype(a_k0_m_k1_grid_desc),
                                    decltype(b_k0_n_k1_grid_desc),
                                    decltype(c_m_n_grid_desc),
                                    MPerBlock,
                                    NPerBlock,
                                    KPerBlock,
                                    MPerXDL,
                                    NPerXDL,
                                    K1,
                                    MRepeat,
                                    NRepeat,
                                    ABlockTransferThreadSliceLengths_K0_M_K1,
                                    ABlockTransferThreadClusterLengths_K0_M_K1,
                                    Sequence<0, 2, 1>,
                                    Sequence<0, 2, 1>,
                                    1,
                                    ABlockTransferSrcScalarPerVector_M,
                                    ABlockTransferDstScalarPerVector_K1,
                                    false, // don't move back src coordinate after threadwise copy
                                    BBlockTransferThreadSliceLengths_K0_N_K1,
                                    BBlockTransferThreadClusterLengths_K0_N_K1,
                                    Sequence<0, 2, 1>,
                                    Sequence<0, 2, 1>,
                                    1,
                                    BBlockTransferSrcScalarPerVector_N,
                                    BBlockTransferDstScalarPerVector_K1,
                                    false, // don't move back src coordinate after threadwise copy
                                    Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
                                    7,
                                    CThreadTransferDstScalarPerVector,
                                    decltype(a_k0_m_k1_grid_step_hacks),
                                    decltype(b_k0_n_k1_grid_step_hacks),
                                    decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                    decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
                                    decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
                                    false // CAccessOrderMRepeatNRepeat
                                    >(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
                                      static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
                                      static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
                                      a_k0_m_k1_grid_desc,
                                      b_k0_n_k1_grid_desc,
                                      c_m_n_grid_desc,
                                      a_k0_m_k1_grid_step_hacks,
                                      b_k0_n_k1_grid_step_hacks,
                                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                                      a_k0_m_k1_grid_move_slice_window_step_hacks,
                                      b_k0_n_k1_grid_move_slice_window_step_hacks,
                                      nrepeat);

        // GEMM does 2*M*N*K flops; ave_time is in ms, so this yields TFlop/s.
        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    c_m_n_device_buf.FromDevice(c_m_n.mData.data());
}
||||
219
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
Normal file
219
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
Normal file
@@ -0,0 +1,219 @@
|
||||
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"

/// Host-side launcher for an XDL (matrix-core) GEMM with A stored as [K, M],
/// B as [N, K] and C as [M, N]:  C(m, n) = sum_k A(k, m) * B(n, k).
///
/// Copies the host tensors to device memory, unmerges the K dimension into
/// [K0, K1] (K1 is the per-thread vector length along K), runs
/// driver_gemm_xdlops_v2r3 five times (each run internally averaged over
/// `nrepeat` kernel launches), prints the timing and achieved TFlop/s of each
/// run, and copies the result back into c_m_n.
///
/// @param a_k_m_grid_desc tensor descriptor of A, dimension order [K, M]
/// @param b_n_k_grid_desc tensor descriptor of B, dimension order [N, K]
/// @param c_m_n_grid_desc tensor descriptor of C, dimension order [M, N]
/// @param a_k_m           host input tensor A
/// @param b_n_k           host input tensor B
/// @param c_m_n           host output tensor C (overwritten with the result)
/// @param nrepeat         number of kernel launches averaged per timed run
template <typename ABType,
          typename AccType,
          typename CType,
          typename ADesc,
          typename BDesc,
          typename CDesc>
void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
                                 const BDesc& b_n_k_grid_desc,
                                 const CDesc& c_m_n_grid_desc,
                                 const Tensor<ABType>& a_k_m,
                                 const Tensor<ABType>& b_n_k,
                                 Tensor<CType>& c_m_n,
                                 ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
    DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
    DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());

    a_k_m_device_buf.ToDevice(a_k_m.mData.data());
    b_n_k_device_buf.ToDevice(b_n_k.mData.data());
    c_m_n_device_buf.ToDevice(c_m_n.mData.data());

    // Compile-time tuning parameters; exactly one preprocessor branch is active.
#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1      = 4;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 4, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_M  = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;

    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1      = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_M  = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif

    const auto K = a_k_m_grid_desc.GetLength(I0);
    const auto M = a_k_m_grid_desc.GetLength(I1);
    const auto N = b_n_k_grid_desc.GetLength(I0);

    constexpr auto K1Number = Number<K1>{};
    const auto K0           = K / K1Number;

    // Reshape A: [K, M] -> [K0, M, K1] (K unmerged into K0 x K1).
    const auto a_k0_m_k1_grid_desc =
        transform_tensor_descriptor(a_k_m_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                                               make_pass_through_transform(M)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // Reshape B: [N, K] -> [K0, N, K1] (B is N-major; K is its fastest dim).
    const auto b_k0_n_k1_grid_desc =
        transform_tensor_descriptor(b_n_k_grid_desc,
                                    make_tuple(make_pass_through_transform(N),
                                               make_unmerge_transform(make_tuple(K0, K1Number))),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<1>{}, Sequence<0, 2>{}));

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto a_k0_m_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
                              Sequence<0, 0, 0>{},   // 1+: M
                              Sequence<0, 0, 0>{}),  // 2+: K1
                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
                              Sequence<0, 0, 0>{},   // 1-: M
                              Sequence<0, 0, 0>{})); // 2-: K1

    constexpr auto b_k0_n_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
                              Sequence<0, 0, 0>{},   // 1+: N
                              Sequence<0, 0, 0>{}),  // 2+: K1
                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
                              Sequence<0, 0, 0>{},   // 1-: N
                              Sequence<0, 0, 0>{})); // 2-: K1

    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};

    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};

    // Run the timed benchmark 5 times to expose run-to-run variance.
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops_v2r3<BlockSize,
                                    ABType,
                                    AccType,
                                    CType,
                                    InMemoryDataOperationEnum_t::Set,
                                    decltype(a_k0_m_k1_grid_desc),
                                    decltype(b_k0_n_k1_grid_desc),
                                    decltype(c_m_n_grid_desc),
                                    MPerBlock,
                                    NPerBlock,
                                    KPerBlock,
                                    MPerXDL,
                                    NPerXDL,
                                    K1,
                                    MRepeat,
                                    NRepeat,
                                    ABlockTransferThreadSliceLengths_K0_M_K1,
                                    ABlockTransferThreadClusterLengths_K0_M_K1,
                                    Sequence<0, 2, 1>,
                                    Sequence<0, 2, 1>,
                                    1,
                                    ABlockTransferSrcScalarPerVector_M,
                                    ABlockTransferDstScalarPerVector_K1,
                                    false, // don't move back src coordinate after threadwise copy
                                    BBlockTransferThreadSliceLengths_K0_N_K1,
                                    BBlockTransferThreadClusterLengths_K0_N_K1,
                                    Sequence<1, 0, 2>,
                                    Sequence<1, 0, 2>,
                                    2,
                                    BBlockTransferSrcScalarPerVector_K1,
                                    BBlockTransferDstScalarPerVector_K1,
                                    false, // don't move back src coordinate after threadwise copy
                                    Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
                                    7,
                                    CThreadTransferDstScalarPerVector,
                                    decltype(a_k0_m_k1_grid_step_hacks),
                                    decltype(b_k0_n_k1_grid_step_hacks),
                                    decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                    decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
                                    decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
                                    false // CAccessOrderMRepeatNRepeat
                                    >(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
                                      static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
                                      static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
                                      a_k0_m_k1_grid_desc,
                                      b_k0_n_k1_grid_desc,
                                      c_m_n_grid_desc,
                                      a_k0_m_k1_grid_step_hacks,
                                      b_k0_n_k1_grid_step_hacks,
                                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                                      a_k0_m_k1_grid_move_slice_window_step_hacks,
                                      b_k0_n_k1_grid_move_slice_window_step_hacks,
                                      nrepeat);

        // GEMM does 2*M*N*K flops; ave_time is in ms, so this yields TFlop/s.
        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    c_m_n_device_buf.FromDevice(c_m_n.mData.data());
}
||||
219
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
Normal file
219
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
Normal file
@@ -0,0 +1,219 @@
|
||||
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"

/// Host-side launcher for an XDL (matrix-core) GEMM with A stored as [M, K],
/// B as [K, N] and C as [M, N]:  C(m, n) = sum_k A(m, k) * B(k, n).
///
/// Copies the host tensors to device memory, unmerges the K dimension into
/// [K0, K1] (K1 is the per-thread vector length along K), runs
/// driver_gemm_xdlops_v2r3 five times (each run internally averaged over
/// `nrepeat` kernel launches), prints the timing and achieved TFlop/s of each
/// run, and copies the result back into c_m_n.
///
/// @param a_m_k_grid_desc tensor descriptor of A, dimension order [M, K]
/// @param b_k_n_grid_desc tensor descriptor of B, dimension order [K, N]
/// @param c_m_n_grid_desc tensor descriptor of C, dimension order [M, N]
/// @param a_m_k           host input tensor A
/// @param b_k_n           host input tensor B
/// @param c_m_n           host output tensor C (overwritten with the result)
/// @param nrepeat         number of kernel launches averaged per timed run
template <typename ABType,
          typename AccType,
          typename CType,
          typename ADesc,
          typename BDesc,
          typename CDesc>
void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
                                 const BDesc& b_k_n_grid_desc,
                                 const CDesc& c_m_n_grid_desc,
                                 const Tensor<ABType>& a_m_k,
                                 const Tensor<ABType>& b_k_n,
                                 Tensor<CType>& c_m_n,
                                 ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
    c_m_n_device_buf.ToDevice(c_m_n.mData.data());

    // Compile-time tuning parameters; exactly one preprocessor branch is active.
#if 1
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1      = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_N  = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 256;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1      = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 4, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_N  = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif

    const auto K = a_m_k_grid_desc.GetLength(I1);
    const auto M = a_m_k_grid_desc.GetLength(I0);
    const auto N = b_k_n_grid_desc.GetLength(I1);

    constexpr auto K1Number = Number<K1>{};
    const auto K0           = K / K1Number;

    // Reshape A: [M, K] -> [K0, M, K1] (A is M-major; K is its fastest dim).
    const auto a_k0_m_k1_grid_desc =
        transform_tensor_descriptor(a_m_k_grid_desc,
                                    make_tuple(make_pass_through_transform(M),
                                               make_unmerge_transform(make_tuple(K0, K1Number))),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<1>{}, Sequence<0, 2>{}));

    // Reshape B: [K, N] -> [K0, N, K1].
    const auto b_k0_n_k1_grid_desc =
        transform_tensor_descriptor(b_k_n_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                                               make_pass_through_transform(N)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto a_k0_m_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
                              Sequence<0, 0, 0>{},   // 1+: M
                              Sequence<0, 0, 0>{}),  // 2+: K1
                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
                              Sequence<0, 0, 0>{},   // 1-: M
                              Sequence<0, 0, 0>{})); // 2-: K1

    constexpr auto b_k0_n_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
                              Sequence<0, 0, 0>{},   // 1+: N
                              Sequence<0, 0, 0>{}),  // 2+: K1
                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
                              Sequence<0, 0, 0>{},   // 1-: N
                              Sequence<0, 0, 0>{})); // 2-: K1

    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};

    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};

    // Run the timed benchmark 5 times to expose run-to-run variance.
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops_v2r3<BlockSize,
                                    ABType,
                                    AccType,
                                    CType,
                                    InMemoryDataOperationEnum_t::Set,
                                    decltype(a_k0_m_k1_grid_desc),
                                    decltype(b_k0_n_k1_grid_desc),
                                    decltype(c_m_n_grid_desc),
                                    MPerBlock,
                                    NPerBlock,
                                    KPerBlock,
                                    MPerXDL,
                                    NPerXDL,
                                    K1,
                                    MRepeat,
                                    NRepeat,
                                    ABlockTransferThreadSliceLengths_K0_M_K1,
                                    ABlockTransferThreadClusterLengths_K0_M_K1,
                                    Sequence<1, 0, 2>,
                                    Sequence<1, 0, 2>,
                                    2,
                                    ABlockTransferSrcScalarPerVector_K1,
                                    ABlockTransferDstScalarPerVector_K1,
                                    false, // don't move back src coordinate after threadwise copy
                                    BBlockTransferThreadSliceLengths_K0_N_K1,
                                    BBlockTransferThreadClusterLengths_K0_N_K1,
                                    Sequence<0, 2, 1>,
                                    Sequence<0, 2, 1>,
                                    1,
                                    BBlockTransferSrcScalarPerVector_N,
                                    BBlockTransferDstScalarPerVector_K1,
                                    false, // don't move back src coordinate after threadwise copy
                                    Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
                                    7,
                                    CThreadTransferDstScalarPerVector,
                                    decltype(a_k0_m_k1_grid_step_hacks),
                                    decltype(b_k0_n_k1_grid_step_hacks),
                                    decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                    decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
                                    decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
                                    false // CAccessOrderMRepeatNRepeat
                                    >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                      static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
                                      static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
                                      a_k0_m_k1_grid_desc,
                                      b_k0_n_k1_grid_desc,
                                      c_m_n_grid_desc,
                                      a_k0_m_k1_grid_step_hacks,
                                      b_k0_n_k1_grid_step_hacks,
                                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                                      a_k0_m_k1_grid_move_slice_window_step_hacks,
                                      b_k0_n_k1_grid_move_slice_window_step_hacks,
                                      nrepeat);

        // GEMM does 2*M*N*K flops; ave_time is in ms, so this yields TFlop/s.
        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    c_m_n_device_buf.FromDevice(c_m_n.mData.data());
}
||||
275
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
Normal file
275
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
Normal file
@@ -0,0 +1,275 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType,
|
||||
typename AccType,
|
||||
typename CType,
|
||||
typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc>
|
||||
void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
|
||||
const BDesc& b_n_k_grid_desc,
|
||||
const CDesc& c_m_n_grid_desc,
|
||||
const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k_grid_desc.GetLength(I1);
|
||||
const auto M = a_m_k_grid_desc.GetLength(I0);
|
||||
const auto N = b_n_k_grid_desc.GetLength(I0);
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_m_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(M),
|
||||
make_unmerge_transform(make_tuple(K0, K1Number))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
transform_tensor_descriptor(b_n_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_unmerge_transform(make_tuple(K0, K1Number))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: N
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: N
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector_K1,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector_K1,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
7,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
|
||||
}
|
||||
294
host/driver_offline/src/gemm_driver_offline.cpp
Normal file
294
host/driver_offline/src/gemm_driver_offline.cpp
Normal file
@@ -0,0 +1,294 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
#include "host_gemm.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_xdlops_mk_kn_mn.hpp"
|
||||
#include "device_gemm_xdlops_mk_nk_mn.hpp"
|
||||
#include "device_gemm_xdlops_km_kn_mn.hpp"
|
||||
#include "device_gemm_xdlops_km_nk_mn.hpp"
|
||||
|
||||
#define USE_GEMM_XDL_MK_KN_MN 1
|
||||
#define USE_GEMM_XDL_MK_NK_MN 1
|
||||
#define USE_GEMM_XDL_KM_KN_MN 1
|
||||
#define USE_GEMM_XDL_KM_NK_MN 1
|
||||
|
||||
// Device GEMM kernel variant, selected via the "algo" command-line argument.
// The name encodes the expected operand layouts (see GemmMatrixLayout):
// each two-letter group is the row-major dimension order of A, B and C.
// The numeric comments are the values accepted on the command line.
enum GemmAlgo
{
    Xdl_MK_KN_MN, // 0
    Xdl_MK_NK_MN, // 1
    Xdl_KM_KN_MN, // 2
    Xdl_KM_NK_MN, // 3
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
// dynamic mode
|
||||
if(argc != 10)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: M, N, K\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
const auto algo = static_cast<GemmAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
const index_t M = std::stoi(argv[7]);
|
||||
const index_t N = std::stoi(argv[8]);
|
||||
const index_t K = std::stoi(argv[9]);
|
||||
|
||||
#if 0
|
||||
using ab_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using c_data_t = float;
|
||||
#elif 1
|
||||
using ab_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using c_data_t = half_t;
|
||||
#elif 1
|
||||
using ab_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using c_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> a_lengths_host(2), b_lengths_host(2), c_lengths_host(2);
|
||||
std::vector<std::size_t> a_strides_host(2), b_strides_host(2), c_strides_host(2);
|
||||
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
a_strides_host[0] = static_cast<std::size_t>(K);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
b_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
b_strides_host[0] = static_cast<std::size_t>(N);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
c_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
c_strides_host[0] = static_cast<std::size_t>(N);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
a_strides_host[0] = static_cast<std::size_t>(K);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
b_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
b_strides_host[0] = static_cast<std::size_t>(K);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
c_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
c_strides_host[0] = static_cast<std::size_t>(N);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(M);
|
||||
a_strides_host[0] = static_cast<std::size_t>(M);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
b_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
b_strides_host[0] = static_cast<std::size_t>(N);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
c_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
c_strides_host[0] = static_cast<std::size_t>(N);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(M);
|
||||
a_strides_host[0] = static_cast<std::size_t>(M);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
b_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
b_strides_host[0] = static_cast<std::size_t>(K);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
|
||||
c_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
c_strides_host[0] = static_cast<std::size_t>(N);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
|
||||
Tensor<ab_data_t> a(a_lengths_host, a_strides_host);
|
||||
Tensor<ab_data_t> b(b_lengths_host, b_strides_host);
|
||||
Tensor<c_data_t> c_host(c_lengths_host, c_strides_host);
|
||||
Tensor<c_data_t> c_device(c_lengths_host, c_strides_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(a.mDesc, std::cout << "a: ");
|
||||
ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: ");
|
||||
ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: ");
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
a.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
a.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
a.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
|
||||
auto f_make_for_device_mk_kn_mn = [&]() {
|
||||
const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
|
||||
const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1));
|
||||
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
|
||||
|
||||
return make_tuple(a_desc, b_desc, c_desc);
|
||||
};
|
||||
|
||||
auto f_make_for_device_mk_nk_mn = [&]() {
|
||||
const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
|
||||
const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1));
|
||||
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
|
||||
|
||||
return make_tuple(a_desc, b_desc, c_desc);
|
||||
};
|
||||
|
||||
auto f_make_for_device_km_kn_mn = [&]() {
|
||||
const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1));
|
||||
const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1));
|
||||
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
|
||||
|
||||
return make_tuple(a_desc, b_desc, c_desc);
|
||||
};
|
||||
|
||||
auto f_make_for_device_km_nk_mn = [&]() {
|
||||
const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1));
|
||||
const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1));
|
||||
const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
|
||||
|
||||
return make_tuple(a_desc, b_desc, c_desc);
|
||||
};
|
||||
|
||||
#if USE_GEMM_XDL_MK_KN_MN
|
||||
if(algo == GemmAlgo::Xdl_MK_KN_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto descs = f_make_for_device_mk_kn_mn();
|
||||
|
||||
device_gemm_xdlops_mk_kn_mn<ab_data_t, acc_data_t, c_data_t>(
|
||||
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_MK_NK_MN
|
||||
if(algo == GemmAlgo::Xdl_MK_NK_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto descs = f_make_for_device_mk_nk_mn();
|
||||
|
||||
device_gemm_xdlops_mk_nk_mn<ab_data_t, acc_data_t, c_data_t>(
|
||||
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_KN_MN
|
||||
if(algo == GemmAlgo::Xdl_KM_KN_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto descs = f_make_for_device_km_kn_mn();
|
||||
|
||||
device_gemm_xdlops_km_kn_mn<ab_data_t, acc_data_t, c_data_t>(
|
||||
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_NK_MN
|
||||
if(algo == GemmAlgo::Xdl_KM_NK_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto descs = f_make_for_device_km_nk_mn();
|
||||
|
||||
device_gemm_xdlops_km_nk_mn<ab_data_t, acc_data_t, c_data_t>(
|
||||
descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_gemm(a, b, c_host, layout);
|
||||
|
||||
check_error(c_host, c_device);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_host : ", c_host.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_device: ", c_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
12
host/host_tensor/include/gemm_common.hpp
Normal file
12
host/host_tensor/include/gemm_common.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
#ifndef GEMM_COMMON_HPP
#define GEMM_COMMON_HPP

// Memory layouts of the GEMM operands for C[M,N] = A * B.
// Each two-letter group is the row-major dimension order of A, B and C:
//   MK_KN_MN: A is M x K,              B is K x N,              C is M x N
//   MK_NK_MN: A is M x K,              B is N x K (transposed), C is M x N
//   KM_KN_MN: A is K x M (transposed), B is K x N,              C is M x N
//   KM_NK_MN: A is K x M (transposed), B is N x K (transposed), C is M x N
// The numeric comments are the values accepted on the command line.
enum GemmMatrixLayout
{
    MK_KN_MN, // 0
    MK_NK_MN, // 1
    KM_KN_MN, // 2
    KM_NK_MN, // 3
};

#endif
|
||||
87
host/host_tensor/include/host_gemm.hpp
Normal file
87
host/host_tensor/include/host_gemm.hpp
Normal file
@@ -0,0 +1,87 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
|
||||
template <typename AType, typename BType, typename CType>
|
||||
void host_gemm(const Tensor<AType>& a,
|
||||
const Tensor<BType>& b,
|
||||
Tensor<CType>& c,
|
||||
const GemmMatrixLayout layout)
|
||||
{
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
auto f_mk_nk_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
auto f_km_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
auto f_km_nk_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user