Fix wrong answer when interleaved=false

This commit is contained in:
PoYen, Chen
2024-07-10 12:50:00 +00:00
parent 8c733fb3be
commit 52da00acd6

View File

@@ -36,13 +36,14 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>
self(i) = input_bhsd(i);
return;
}
assert(i_d < rotary_dim);
const index_t i_s = i[2];
const ComputeDataType cos = type_convert<ComputeDataType>(
interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % rotary_dim));
interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % cos_sd.get_length(1)));
const ComputeDataType sin = type_convert<ComputeDataType>(
interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % rotary_dim));
interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % sin_sd.get_length(1)));
const ComputeDataType half_rotated_input = [&] {
const index_t i_b = i[0];
const index_t i_h = i[1];