Support 1 dimension

This commit is contained in:
rocking
2022-05-17 20:50:03 +08:00
parent 0d26477a86
commit 4af77e1f0e
4 changed files with 151 additions and 4 deletions

View File

@@ -67,6 +67,11 @@ int main()
ck::index_t N = 1024;
ck::index_t Stride = 1024;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
@@ -74,8 +79,7 @@ int main()
Tensor<ABDataType> a_m_n(f_host_tensor_descriptor2d(M, N, Stride));
Tensor<ABDataType> b_n(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
std::vector<std::size_t>({1}));
Tensor<ABDataType> b_n(f_host_tensor_descriptor1d(N, 1));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, Stride));