Fix wrong answer when interleaved=true

This commit is contained in:
PoYen, Chen
2024-07-11 00:26:18 +00:00
parent 52da00acd6
commit ee365bbc66

View File

@@ -53,8 +53,9 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>
if(interleaved)
{
const index_t pos = (i_d < half_hdim ? (i_d * 2 + 1) : (i_d - half_hdim) * 2);
const ComputeDataType sign = (i_d < half_hdim ? 1 : -1);
const bool is_even = (i_d % 2 == 0);
const index_t pos = i_d + (is_even ? 1 : -1);
const ComputeDataType sign = (is_even ? -1 : 1);
return sign * type_convert<ComputeDataType>(input_bhsd(i_b, i_h, i_s, pos));
}
else