mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
external api for gemm + layernorm (#285)
* Extract base class for elementwise
* Refactor interface of DeviceGemmReduce. Do not use tuple in interface
* [What] Rename d into reduce in gemm + reduction related code
  [Why] Prepare to add d term for add
* Unify base class of gemm + reduce and gemm + bias + add + reduce
* 1. Rename gemm_bias_add_reduce for external api
  2. Refine cmake
* Add normalize device operation
* [What] Reorder the argument
  [Why] Because d0 is also the input of c.
* Add type string
* Add example of gemm_bias_add_layernorm via external api
* Refactor example code
* clang-format
* Fix compile error
* clang-format
* Add external api for gemm_add_add_layernorm and normalize
* Add client example
* clang-format
This commit is contained in:
@@ -79,15 +79,17 @@ int main()
|
||||
a_m_device_buf.ToDevice(a_m.mData.data());
|
||||
b_m_device_buf.ToDevice(b_m.mData.data());
|
||||
|
||||
std::array<const void*, 2> input = {a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::vector<ck::index_t> a_strides = {1};
|
||||
std::vector<ck::index_t> b_strides = {1};
|
||||
std::vector<ck::index_t> c_strides = {1};
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_device_buf.GetDeviceBuffer(),
|
||||
c_m_device_buf.GetDeviceBuffer(),
|
||||
{M},
|
||||
{1},
|
||||
{1},
|
||||
{1},
|
||||
Add{});
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(
|
||||
input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user