mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 05:55:39 +00:00
Merge remote-tracking branch 'origin/develop' into myamlak/cgemm
This commit is contained in:
@@ -1,10 +1,5 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include <math.h>
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
@@ -13,7 +8,6 @@
|
||||
|
||||
#include "device_tensor.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
|
||||
#include "device_binary_elementwise.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
@@ -26,7 +20,7 @@ using EltwiseComputeDataType = F32;
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
|
||||
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 2, 8>;
|
||||
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 2, 8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
@@ -37,6 +31,8 @@ template <typename HostTensorA,
|
||||
void host_broadcast2D(
|
||||
HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor)
|
||||
{
|
||||
using ctype = ck::remove_reference_t<decltype(C(0, 0))>;
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
@@ -53,7 +49,7 @@ void host_broadcast2D(
|
||||
ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
|
||||
functor(Cmn, Amn, Bm);
|
||||
}
|
||||
C(m, n) = static_cast<ComputeDataType>(Cmn);
|
||||
C(m, n) = static_cast<ctype>(Cmn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include <math.h>
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
@@ -13,7 +8,6 @@
|
||||
|
||||
#include "device_tensor.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
|
||||
#include "device_binary_elementwise.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
@@ -26,7 +20,7 @@ using EltwiseComputeDataType = F32;
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
|
||||
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 1, 8>;
|
||||
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 1, 8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
@@ -36,13 +30,15 @@ template <typename HostTensorA,
|
||||
void host_elementwise1D(
|
||||
HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor)
|
||||
{
|
||||
using ctype = ck::remove_reference_t<decltype(C(0))>;
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
ComputeDataType Am = static_cast<ComputeDataType>(A(m));
|
||||
ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
|
||||
ComputeDataType Cm = 0;
|
||||
functor(Cm, Am, Bm);
|
||||
C(m) = static_cast<ComputeDataType>(Cm);
|
||||
C(m) = static_cast<ctype>(Cm);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include <math.h>
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_reduce_util.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "host_utility.hpp"
|
||||
|
||||
#include "device_tensor.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
|
||||
#include "device_binary_elementwise.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
@@ -27,7 +21,7 @@ using EltwiseComputeDataType = F32;
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
|
||||
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 4, 8>;
|
||||
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 4, 8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
@@ -40,6 +34,8 @@ void host_elementwise4D(HostTensorC& C,
|
||||
const std::vector<std::size_t>& shape,
|
||||
Functor functor)
|
||||
{
|
||||
using ctype = ck::remove_reference_t<decltype(C(0, 0, 0, 0))>;
|
||||
|
||||
for(std::size_t n = 0; n < shape[0]; ++n)
|
||||
for(std::size_t c = 0; c < shape[1]; ++c)
|
||||
for(std::size_t h = 0; h < shape[2]; ++h)
|
||||
@@ -49,7 +45,7 @@ void host_elementwise4D(HostTensorC& C,
|
||||
ComputeDataType b_val = static_cast<ComputeDataType>(B(n, c, h, w));
|
||||
ComputeDataType c_val = 0;
|
||||
functor(c_val, a_val, b_val);
|
||||
C(n, c, h, w) = static_cast<ComputeDataType>(c_val);
|
||||
C(n, c, h, w) = static_cast<ctype>(c_val);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,14 +71,15 @@ int main()
|
||||
b_m_device_buf.ToDevice(b_m.mData.data());
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_device_buf.GetDeviceBuffer(),
|
||||
c_m_device_buf.GetDeviceBuffer(),
|
||||
ck::to_int_vector(nchw),
|
||||
ck::to_int_vector(a_m.mDesc.GetStrides()),
|
||||
ck::to_int_vector(b_m.mDesc.GetStrides()),
|
||||
ck::to_int_vector(c_m.mDesc.GetStrides()),
|
||||
Add{});
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(
|
||||
a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_device_buf.GetDeviceBuffer(),
|
||||
c_m_device_buf.GetDeviceBuffer(),
|
||||
ck::convert_vector_element_type<std::size_t, ck::index_t>(nchw),
|
||||
ck::convert_vector_element_type<std::size_t, ck::index_t>(a_m.mDesc.GetStrides()),
|
||||
ck::convert_vector_element_type<std::size_t, ck::index_t>(b_m.mDesc.GetStrides()),
|
||||
ck::convert_vector_element_type<std::size_t, ck::index_t>(c_m.mDesc.GetStrides()),
|
||||
Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user