diff --git a/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp b/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp index 601336dabe..926ff1a595 100644 --- a/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp +++ b/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp @@ -36,13 +36,14 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor self(i) = input_bhsd(i); return; } + assert(i_d < rotary_dim); const index_t i_s = i[2]; const ComputeDataType cos = type_convert( - interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % rotary_dim)); + interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % cos_sd.get_length(1))); const ComputeDataType sin = type_convert( - interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % rotary_dim)); + interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % sin_sd.get_length(1))); const ComputeDataType half_rotated_input = [&] { const index_t i_b = i[0]; const index_t i_h = i[1];