diff --git a/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp b/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp index 712226a8c0..542c9df9ce 100644 --- a/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp +++ b/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp @@ -28,12 +28,46 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor const index_t rotary_dim = cos_sd.get_length(1) * 2; assert(rotary_dim <= input_bhsd.get_length(3)); - (void)rotary_dim; - (void)input_bhsd; - (void)sin_sd; - (void)cos_sd; - (void)interleaved; - (void)output_bhsd; + + output_bhsd.ForEach([&](auto& self, auto i) { + const index_t i_d = i[3]; + if(rotary_dim <= i_d) + { + self(i) = input_bhsd(i); + return; + } + + const index_t i_s = i[2]; + + const ComputeDataType cos = + (interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % rotary_dim)); + const ComputeDataType sin = + (interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % rotary_dim)); + const ComputeDataType half_rotated_input = [&] { + const index_t i_b = i[0]; + const index_t i_h = i[1]; + + const index_t hdim = input_bhsd.get_length(3); + const index_t half_hdim = hdim / 2; + + if(interleaved) + { + const index_t pos = (i_d < half_hdim ? (i_d * 2 + 1) : (i_d - half_hdim) * 2); + const ComputeDataType sign = (i_d < half_hdim ? 1 : -1); + return sign * type_convert(input_bhsd(i_b, i_h, i_s, pos)); + } + else + { + const index_t pos = (i_d + half_hdim) % hdim; + const ComputeDataType sign = (pos < half_hdim ? 1 : -1); + return sign * type_convert(input_bhsd(i_b, i_h, i_s, pos)); + } + }(); + ComputeDataType result = + type_convert(input_bhsd(i)) * cos + half_rotated_input * sin; + + self(i) = type_convert(result); + }); } } // namespace ck_tile