Improve external interface for GEMM and GEMM+add+add+fastgelu (#311)

* interface for GEMM and GEMM+add+add+fastgelu

* rename namespace

* instance factory

* fix build

* fix build; add GEMM client example

* clean
This commit is contained in:
Chao Liu
2022-06-30 22:11:00 -05:00
committed by GitHub
parent fa9a0a5cfb
commit 0dcb3496cf
259 changed files with 2915 additions and 2969 deletions

View File

@@ -11,6 +11,8 @@
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
@@ -27,30 +29,6 @@ enum struct GemmMatrixLayout
KM_NK_MN, // 3
};
using DeviceGemmSplitKNoOpPtr = ck::tensor_operation::device::DeviceGemmSplitKPtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
@@ -82,6 +60,11 @@ struct gemmArgs
int test_gemm(const gemmArgs& args)
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
bool a_row_major, b_row_major, c_row_major;
switch(args.layout)
@@ -152,64 +135,79 @@ int test_gemm(const gemmArgs& args)
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances
std::vector<DeviceGemmSplitKNoOpPtr> gemm_ptrs;
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool success = false;
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
float,
float,
float,
PassThrough,
PassThrough,
PassThrough>;
const auto gemm_ptrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
static_cast<float*>(b_device_buf.GetDeviceBuffer()),
static_cast<float*>(c_device_buf.GetDeviceBuffer()),
args.M,
args.N,
args.K,
args.StrideA,
args.StrideB,
args.StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
args.KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get());
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(!check_out(c_m_n_host_result, c_m_n_device_result))
{
success = false;
break;
}
success = true;
}
}
return success;
};
bool success = false;
if(args.layout == GemmMatrixLayout::MK_KN_MN)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
success = test(Row{}, Row{}, Row{});
}
else if(args.layout == GemmMatrixLayout::MK_NK_MN)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
success = test(Row{}, Col{}, Row{});
}
else if(args.layout == GemmMatrixLayout::KM_KN_MN)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
success = test(Col{}, Row{}, Row{});
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
success = test(Col{}, Col{}, Row{});
}
bool success = false;
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
static_cast<float*>(b_device_buf.GetDeviceBuffer()),
static_cast<float*>(c_device_buf.GetDeviceBuffer()),
args.M,
args.N,
args.K,
args.StrideA,
args.StrideB,
args.StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
args.KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get());
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(!check_out(c_m_n_host_result, c_m_n_device_result))
{
success = false;
break;
}
success = true;
}
}
auto error_code = 0;
if(success)
{