[rocm-libraries] ROCm/rocm-libraries#6168 (commit 2968835)

[CK][CK Tile] Clamp element space size to max int32 value
 (#6168)

## Motivation

Fix oob check by clamping element space size to avoid overflow when
tensor is larger than 2GB.

## Technical Details

- It is possible that tensor could be larger than 2GB but offsets no, so
element space size must be clamped to 2GB if value is larger.

## Test Plan

CI

## Test Result

Pending

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

https://github.com/ROCm/composable_kernel/issues/3722

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
This commit is contained in:
Bartłomiej Kocot
2026-04-20 15:33:18 +00:00
committed by assistant-librarian[bot]
parent d4236de1ba
commit 60ff5693c4

View File

@@ -236,12 +236,13 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
namespace detail {
template <typename Lengths, typename Strides, index_t I, typename AccOld>
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
const Strides& strides,
number<I> i,
AccOld acc_old)
CK_TILE_HOST_DEVICE constexpr long_index_t calculate_element_space_size_impl(const Lengths& lengths,
const Strides& strides,
number<I> i,
AccOld acc_old)
{
auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];
long_index_t acc_new = acc_old + static_cast<long_index_t>(lengths[i] - number<1>{}) *
static_cast<long_index_t>(strides[i]);
if constexpr(i.value < Lengths::size() - 1)
{
@@ -287,8 +288,12 @@ make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto element_space_size =
const long_index_t element_space_size_long =
detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});
constexpr long_index_t element_space_size_clamp_value =
static_cast<long_index_t>(std::numeric_limits<index_t>::max());
const index_t element_space_size =
static_cast<index_t>(std::min(element_space_size_long, element_space_size_clamp_value));
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
@@ -323,8 +328,12 @@ make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
const auto desc_0 = [&]() {
const auto element_space_size = detail::calculate_element_space_size_impl(
const auto element_space_size_long = detail::calculate_element_space_size_impl(
lengths, strides, number<0>{}, long_number<1>{});
constexpr long_index_t element_space_size_clamp_value =
static_cast<long_index_t>(std::numeric_limits<index_t>::max());
const index_t element_space_size =
static_cast<index_t>(std::min(element_space_size_long, element_space_size_clamp_value));
const auto transforms = make_tuple(make_offset_transform(element_space_size, os));