diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp index cda2fb0bb5..0ec975441f 100644 --- a/include/ck_tile/core/tensor/tensor_descriptor.hpp +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -236,12 +236,13 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, namespace detail { template -CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths, - const Strides& strides, - number i, - AccOld acc_old) +CK_TILE_HOST_DEVICE constexpr long_index_t calculate_element_space_size_impl(const Lengths& lengths, + const Strides& strides, + number i, + AccOld acc_old) { - auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i]; + long_index_t acc_new = acc_old + static_cast(lengths[i] - number<1>{}) * + static_cast(strides[i]); if constexpr(i.value < Lengths::size() - 1) { @@ -287,8 +288,12 @@ make_naive_tensor_descriptor(const tuple& lengths, constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; - const auto element_space_size = + const long_index_t element_space_size_long = detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{}); + constexpr long_index_t element_space_size_clamp_value = + static_cast(std::numeric_limits::max()); + const index_t element_space_size = + static_cast(std::min(element_space_size_long, element_space_size_clamp_value)); using GuaranteedVectorLengths = typename sequence_merge::type, @@ -323,8 +328,12 @@ make_naive_tensor_descriptor_with_offset(const tuple& lengths, number = number<-1>{}) { const auto desc_0 = [&]() { - const auto element_space_size = detail::calculate_element_space_size_impl( + const auto element_space_size_long = detail::calculate_element_space_size_impl( lengths, strides, number<0>{}, long_number<1>{}); + constexpr long_index_t element_space_size_clamp_value = + static_cast(std::numeric_limits::max()); + const index_t element_space_size = + static_cast(std::min(element_space_size_long, element_space_size_clamp_value)); const auto transforms = make_tuple(make_offset_transform(element_space_size, os));