Merge remote-tracking branch 'origin/develop' into myamlak/cgemm

This commit is contained in:
myamlak
2022-05-19 12:43:59 +00:00
8 changed files with 119 additions and 106 deletions

View File

@@ -1,10 +1,5 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <math.h>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
@@ -13,7 +8,6 @@
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using F16 = ck::half_t;
@@ -26,7 +20,7 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 2, 8>;
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 2, 8>;
template <typename HostTensorA,
typename HostTensorB,
@@ -37,6 +31,8 @@ template <typename HostTensorA,
void host_broadcast2D(
HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor)
{
using ctype = ck::remove_reference_t<decltype(C(0, 0))>;
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
@@ -53,7 +49,7 @@ void host_broadcast2D(
ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
functor(Cmn, Amn, Bm);
}
C(m, n) = static_cast<ComputeDataType>(Cmn);
C(m, n) = static_cast<ctype>(Cmn);
}
}
}

View File

@@ -1,10 +1,5 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <math.h>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
@@ -13,7 +8,6 @@
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using F16 = ck::half_t;
@@ -26,7 +20,7 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 1, 8>;
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 1, 8>;
template <typename HostTensorA,
typename HostTensorB,
@@ -36,13 +30,15 @@ template <typename HostTensorA,
// Host-side loop over a 1D range: for every index m in [0, M) it converts
// A(m) and B(m) to ComputeDataType, calls `functor` on them, and stores the
// outcome in C(m).
// NOTE(review): the template parameter list (HostTensorC, ComputeDataType,
// Functor, ...) falls in the lines elided before this hunk — confirm there.
void host_elementwise1D(
HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor)
{
// ctype is the type of the expression C(0) after ck::remove_reference_t is
// applied; it is the cast target of the final store below.
using ctype = ck::remove_reference_t<decltype(C(0))>;
for(int m = 0; m < M; ++m)
{
// Both operands are converted to the compute type before the functor call.
ComputeDataType Am = static_cast<ComputeDataType>(A(m));
ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
ComputeDataType Cm = 0;
// Presumably functor writes its result into its first argument (Cm), since
// Cm is what gets stored afterwards — verify against the functor definition.
functor(Cm, Am, Bm);
// NOTE(review): the two stores below appear to be the removed and the added
// side of this diff hunk rendered without +/- markers — confirm. The first
// casts to ComputeDataType, the second to ctype.
C(m) = static_cast<ComputeDataType>(Cm);
C(m) = static_cast<ctype>(Cm);
}
}

View File

@@ -1,20 +1,14 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <math.h>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_reduce_util.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_utility.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using F16 = ck::half_t;
@@ -27,7 +21,7 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 4, 8>;
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 4, 8>;
template <typename HostTensorA,
typename HostTensorB,
@@ -40,6 +34,8 @@ void host_elementwise4D(HostTensorC& C,
const std::vector<std::size_t>& shape,
Functor functor)
{
using ctype = ck::remove_reference_t<decltype(C(0, 0, 0, 0))>;
for(std::size_t n = 0; n < shape[0]; ++n)
for(std::size_t c = 0; c < shape[1]; ++c)
for(std::size_t h = 0; h < shape[2]; ++h)
@@ -49,7 +45,7 @@ void host_elementwise4D(HostTensorC& C,
ComputeDataType b_val = static_cast<ComputeDataType>(B(n, c, h, w));
ComputeDataType c_val = 0;
functor(c_val, a_val, b_val);
C(n, c, h, w) = static_cast<ComputeDataType>(c_val);
C(n, c, h, w) = static_cast<ctype>(c_val);
}
}
@@ -75,14 +71,15 @@ int main()
b_m_device_buf.ToDevice(b_m.mData.data());
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(),
b_m_device_buf.GetDeviceBuffer(),
c_m_device_buf.GetDeviceBuffer(),
ck::to_int_vector(nchw),
ck::to_int_vector(a_m.mDesc.GetStrides()),
ck::to_int_vector(b_m.mDesc.GetStrides()),
ck::to_int_vector(c_m.mDesc.GetStrides()),
Add{});
auto argument = broadcastAdd.MakeArgumentPointer(
a_m_device_buf.GetDeviceBuffer(),
b_m_device_buf.GetDeviceBuffer(),
c_m_device_buf.GetDeviceBuffer(),
ck::convert_vector_element_type<std::size_t, ck::index_t>(nchw),
ck::convert_vector_element_type<std::size_t, ck::index_t>(a_m.mDesc.GetStrides()),
ck::convert_vector_element_type<std::size_t, ck::index_t>(b_m.mDesc.GetStrides()),
ck::convert_vector_element_type<std::size_t, ck::index_t>(c_m.mDesc.GetStrides()),
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{