Files
composable_kernel/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp
Chao Liu f2523a771e adding implicit gemm v4r2
[ROCm/composable_kernel commit: 923578a389]
2019-07-05 15:35:11 -05:00

72 lines
2.3 KiB
C++

#ifndef CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
namespace ck {
// Compile-time descriptor of a 2-D row-major matrix whose rows may be padded.
// All information lives in the template parameters; the object itself is empty.
//
// Template parameters:
//   NRow_      - number of rows.
//   NCol_      - number of columns in each row.
//   RowStride_ - distance, in elements, between the starts of consecutive rows.
//                Must be >= NCol_ (enforced by the constructor's static_assert).
template <index_t NRow_, index_t NCol_, index_t RowStride_>
struct ConstantMatrixDescriptor
{
    __host__ __device__ constexpr ConstantMatrixDescriptor()
    {
        static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
    }

    __host__ __device__ static constexpr index_t NRow() { return NRow_; }

    __host__ __device__ static constexpr index_t NCol() { return NCol_; }

    __host__ __device__ static constexpr index_t RowStride() { return RowStride_; }

    // Lengths of the two dimensions as a compile-time Sequence<NRow, NCol>.
    __host__ __device__ static constexpr auto GetLengths() { return Sequence<NRow_, NCol_>{}; }

    // Number of addressable elements: NRow_ * NCol_ (row padding excluded).
    __host__ __device__ static constexpr index_t GetElementSize() { return NRow_ * NCol_; }

    // Number of elements spanned by the layout: NRow_ * RowStride_.
    // Every row, including the last, is counted with its full stride.
    __host__ __device__ static constexpr index_t GetElementSpace() { return NRow_ * RowStride_; }

    // Linear offset of element (irow, icol): irow * RowStride_ + icol.
    // Marked constexpr (like the other accessors) so calls with constant
    // indices can be evaluated at compile time. No bounds checking is done.
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(index_t irow,
                                                                         index_t icol)
    {
        return irow * RowStride_ + icol;
    }

    // Descriptor of a SubNRow x SubNCol matrix that keeps this matrix's row
    // stride, i.e. a view of a window inside the same storage. The returned
    // descriptor's constructor checks SubNCol <= RowStride_.
    template <index_t SubNRow, index_t SubNCol>
    __host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
                                                                      Number<SubNCol>)
    {
        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride_>{};
    }
};
// Builds the descriptor of a densely packed NRow x NCol matrix. Rows are laid
// out back to back, so the row stride is exactly the column count.
template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor_packed(Number<NRow>, Number<NCol>)
{
    using PackedDesc = ConstantMatrixDescriptor<NRow, NCol, NCol>;
    return PackedDesc();
}
// Builds the descriptor of an NRow x NCol matrix whose consecutive rows start
// RowStride elements apart. All values are taken from the Number<> tag types;
// the argument objects themselves are unused.
template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
    using StridedDesc = ConstantMatrixDescriptor<NRow, NCol, RowStride>;
    return StridedDesc();
}
// Converts a 2-D ConstantTensorDescriptor into the equivalent matrix descriptor.
//
// Deduction only succeeds for a tensor descriptor with lengths
// Sequence<NRow, NCol> and strides Sequence<RowStride, 1>, i.e. a 2-D tensor
// whose innermost dimension is contiguous. Any other tensor descriptor makes
// this overload non-viable.
//
// (Fix: the parameter list was missing its closing parenthesis, which made the
// header fail to compile.)
template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
    ConstantTensorDescriptor<Sequence<NRow, NCol>, Sequence<RowStride, 1>>)
{
    return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}
// Debug helper: prints the label `s` followed by the row count, column count
// and row stride reported by the descriptor type TDesc. The argument object is
// unused; only its type matters.
// NOTE(review): the format uses %u, which assumes index_t is a 32-bit unsigned
// type — confirm against the index_t definition.
template <class TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
    const auto rows   = TDesc::NRow();
    const auto cols   = TDesc::NCol();
    const auto stride = TDesc::RowStride();

    printf("%s NRow %u NCol %u RowStride %u\n", s, rows, cols, stride);
}
} // namespace ck
#endif