external api for gemm + layernorm (#285)

* Extract base class for elementwise

* Refactor interface of DeviceGemmReduce. Do not use tuple in interface

* [What] Rename d into reduce in gemm + reduction related code
[Why] Prepare to add d term for add

* Unify base class of gemm + reduce and gemm + bias + add + reduce

* 1. Rename gemm_bias_add_reduce for external api
 2. Refine cmake

* Add normalize device operation

* [What] Reorder the argument
[Why] Because d0 is also the input of c.

* Add type string

* Add example of gemm_bias_add_layernorm via external api

* Refactor example code

* clang-format

* Fix compile error

* clang-format

* Add external api for gemm_add_add_layernorm and normalize

* Add client example

* clang-format

Author: rocking5566
Date: 2022-06-28 03:25:10 +08:00
Committed by GitHub
Parent: aebd211c36
Commit: 12235112a1
47 changed files with 2577 additions and 1946 deletions


@@ -81,16 +81,22 @@ int main()
     a_device_buf.ToDevice(a.mData.data());
     b_device_buf.ToDevice(b.mData.data());
+    std::array<const void*, 2> input = {a_device_buf.GetDeviceBuffer(),
+                                        b_device_buf.GetDeviceBuffer()};
+    std::array<void*, 1> output = {c_device_buf.GetDeviceBuffer()};
+    std::vector<ck::index_t> a_strides{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()};
+    std::vector<ck::index_t> b_strides{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()};
+    std::vector<ck::index_t> c_strides{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()};
     auto broadcastAdd = DeviceElementwiseAddInstance{};
-    auto argument = broadcastAdd.MakeArgumentPointer(
-        a_device_buf.GetDeviceBuffer(),
-        b_device_buf.GetDeviceBuffer(),
-        c_device_buf.GetDeviceBuffer(),
-        std::vector<ck::index_t>{nchw.begin(), nchw.end()},
-        std::vector<ck::index_t>{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()},
-        std::vector<ck::index_t>{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()},
-        std::vector<ck::index_t>{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()},
-        Add{});
+    auto argument =
+        broadcastAdd.MakeArgumentPointer(input,
+                                         output,
+                                         std::vector<ck::index_t>{nchw.begin(), nchw.end()},
+                                         {{a_strides}, b_strides},
+                                         {c_strides},
+                                         Add{});
     if(!broadcastAdd.IsSupportedArgument(argument.get()))
     {