diff --git a/composable_kernel/include/tensor_description/tensor_coordinate.hpp b/composable_kernel/include/tensor_description/tensor_coordinate.hpp index 1a2774b589..223d0d5bed 100644 --- a/composable_kernel/include/tensor_description/tensor_coordinate.hpp +++ b/composable_kernel/include/tensor_description/tensor_coordinate.hpp @@ -166,7 +166,7 @@ struct MergedTensorCoordinate // do carry check in reversed order, starting from lowest dimension // don't check the highest dimension - static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) { + static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) { constexpr index_t i = ndim_partial_original - 1 - IReverse; if(carry) @@ -182,6 +182,12 @@ struct MergedTensorCoordinate carry = true; } }); + + // highest dimension + if(carry) + { + ++partial_original_id(0); + } }).Else([&](auto) { // shift up multi-id to avoid unsigned integer underflow during intermediate // calculations. After the shift, should have new_multi_id[...] >= 1 @@ -192,7 +198,7 @@ struct MergedTensorCoordinate // do borrow check in reversed order, starting from lowest dimension // don't check the highest dimension - static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) { + static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) { constexpr index_t i = ndim_partial_original - 1 - IReverse; if(borrow) @@ -209,6 +215,12 @@ struct MergedTensorCoordinate } }); + // highest dimension + if(borrow) + { + --partial_original_id(0); + } + // shift back down multi-id // here, should have new_multi_id[...] >= GetLengths() partial_original_id = partial_original_id - partial_original_desc.GetLengths();