mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
* enabled atomic add in tensor copy * added gridwise GEMM * added backward data conv using GEMM + atomic * added backward data conv using GEMM, no atomic
29 lines
969 B
C++
29 lines
969 B
C++
#pragma once
|
|
#include "tensor.hpp"
|
|
#include "common_header.hpp"
|
|
#include "ConstantTensorDescriptor_deprecated.hpp"
|
|
#include "tensor_descriptor.hpp"
|
|
|
|
template <typename ConstTensorDesc, std::size_t... Is>
|
|
auto make_TensorDescriptor_impl(ConstTensorDesc, std::integer_sequence<std::size_t, Is...>)
|
|
{
|
|
std::initializer_list<std::size_t> lengths = {ConstTensorDesc::GetLengths()[Is]...};
|
|
std::initializer_list<std::size_t> strides = {ConstTensorDesc::GetStrides()[Is]...};
|
|
|
|
return TensorDescriptor(lengths, strides);
|
|
}
|
|
|
|
// Convert a compile-time tensor descriptor into a runtime TensorDescriptor.
//
// Generates the index pack 0, 1, ..., GetNumOfDimension() - 1 and forwards it to
// make_TensorDescriptor_impl, which reads one length and one stride per index.
// The argument is used only to deduce ConstTensorDesc.
template <typename ConstTensorDesc>
auto make_TensorDescriptor(ConstTensorDesc)
{
    // std::make_index_sequence<N> is exactly std::make_integer_sequence<std::size_t, N>.
    using dim_indices = std::make_index_sequence<ConstTensorDesc::GetNumOfDimension()>;

    return make_TensorDescriptor_impl(ConstTensorDesc{}, dim_indices{});
}
|
|
|
|
// Print a compile-time tensor descriptor to `os` (std::cout by default).
//
// The descriptor is first converted with make_TensorDescriptor; formatting is
// delegated entirely to ostream_TensorDescriptor. The argument is used only to
// deduce ConstTensorDesc.
template <typename ConstTensorDesc>
void ostream_ConstantTensorDescriptor(ConstTensorDesc, std::ostream& os = std::cout)
{
    // Binding to const& extends the temporary's lifetime through the call below.
    const auto& runtime_desc = make_TensorDescriptor(ConstTensorDesc{});
    ostream_TensorDescriptor(runtime_desc, os);
}
|