Support providing invalid element for tensor view

This commit is contained in:
PoYen, Chen
2024-06-12 02:52:07 +00:00
parent b994668714
commit ff866f6bb6

View File

@@ -222,6 +222,36 @@ make_naive_tensor_view(DataType* p,
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
}
// Creates a tensor view over the memory at `p`, with a caller-supplied
// invalid-element value.
//
// The layout comes from make_naive_tensor_descriptor(), called with
// `lengths`, `strides` and the two compile-time number<> hints. The storage is
// wrapped by make_buffer_view<BufferAddressSpace>(), called with `p`, the
// descriptor's element-space size and `invalid_element_value`.
//
// This overload only participates in overload resolution when
//   * the number of lengths equals the number of strides, and
//   * X is the same type as DataType once cv-qualifiers and references are
//     removed from both.
//
// NOTE(review): how `invalid_element_value` is consumed is decided inside
// make_buffer_view()/buffer_view and is not visible here -- presumably it is the
// value reported for invalid (e.g. out-of-range) accesses; confirm there.
//
// Returns a tensor_view parameterised on the buffer-view type, the descriptor
// type and DstInMemOp, constructed from {buffer view, descriptor}.
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
          memory_operation_enum DstInMemOp      = memory_operation_enum::set,
          typename DataType,
          typename... Lengths,
          typename... Strides,
          typename X,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides) &&
                                      std::is_same_v<remove_cvref_t<DataType>, remove_cvref_t<X>>,
                                  bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view(DataType* p,
                       const tuple<Lengths...>& lengths,
                       const tuple<Strides...>& strides,
                       X invalid_element_value,
                       number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                       number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    // The descriptor is built first because its element-space size is needed
    // to construct the buffer view below.
    auto descriptor =
        make_naive_tensor_descriptor(lengths,
                                     strides,
                                     number<GuaranteedLastDimensionVectorLength>{},
                                     number<GuaranteedLastDimensionVectorStride>{});

    // Wrap the raw pointer, passing the invalid-element value through unchanged.
    auto buf_view = make_buffer_view<BufferAddressSpace>(
        p, descriptor.get_element_space_size(), invalid_element_value);

    // Both locals are deliberately non-const: their decltype()s are template
    // arguments of the result type and must not pick up a const qualifier.
    using result_view_t = tensor_view<decltype(buf_view), decltype(descriptor), DstInMemOp>;

    return result_view_t{buf_view, descriptor};
}
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename DataType,
typename... Lengths,