external api for gemm + layernorm (#285)

* Extract base class for elementwise

* Refactor interface of DeviceGemmReduce. Do not use tuple in interface

* [What] Rename d into reduce in gemm + reduction related code
[Why] Prepare to add d term for add

* Unify base class of gemm + reduce and gemm + bias + add + reduce

* 1. Rename gemm_bias_add_reduce for external api
 2. Refine cmake

* Add normalize device operation

* [What] Reorder the argument
[Why] Because d0 is also the input of c.

* Add type string

* Add example of gemm_bias_add_layernorm  via external api

* Refactor example code

* clang-format

* Fix compile error

* clang-format

* Add external api for gemm_add_add_layernorm and normalize

* Add client example

* clang-format
This commit is contained in:
rocking5566
2022-06-28 03:25:10 +08:00
committed by GitHub
parent aebd211c36
commit 12235112a1
47 changed files with 2577 additions and 1946 deletions

View File

@@ -79,15 +79,17 @@ int main()
a_m_device_buf.ToDevice(a_m.mData.data());
b_m_device_buf.ToDevice(b_m.mData.data());
std::array<const void*, 2> input = {a_m_device_buf.GetDeviceBuffer(),
b_m_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {1};
std::vector<ck::index_t> b_strides = {1};
std::vector<ck::index_t> c_strides = {1};
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(),
b_m_device_buf.GetDeviceBuffer(),
c_m_device_buf.GetDeviceBuffer(),
{M},
{1},
{1},
{1},
Add{});
auto argument = broadcastAdd.MakeArgumentPointer(
input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{