mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Support 1 dimension
This commit is contained in:
@@ -67,6 +67,11 @@ int main()
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t Stride = 1024;
|
||||
|
||||
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({len}),
|
||||
std::vector<std::size_t>({stride}));
|
||||
};
|
||||
|
||||
auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
@@ -74,8 +79,7 @@ int main()
|
||||
|
||||
Tensor<ABDataType> a_m_n(f_host_tensor_descriptor2d(M, N, Stride));
|
||||
|
||||
Tensor<ABDataType> b_n(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
|
||||
std::vector<std::size_t>({1}));
|
||||
Tensor<ABDataType> b_n(f_host_tensor_descriptor1d(N, 1));
|
||||
|
||||
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, Stride));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user