Finish reference_rotary_position_embedding() impl

This commit is contained in:
PoYen, Chen
2024-07-10 09:16:54 +00:00
parent f2d28e8ab4
commit 9d29311da0

View File

@@ -28,12 +28,46 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>
const index_t rotary_dim = cos_sd.get_length(1) * 2;
assert(rotary_dim <= input_bhsd.get_length(3));
(void)rotary_dim;
(void)input_bhsd;
(void)sin_sd;
(void)cos_sd;
(void)interleaved;
(void)output_bhsd;
output_bhsd.ForEach([&](auto& self, auto i) {
const index_t i_d = i[3];
if(rotary_dim <= i_d)
{
self(i) = input_bhsd(i);
return;
}
const index_t i_s = i[2];
const ComputeDataType cos =
(interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % rotary_dim));
const ComputeDataType sin =
(interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % rotary_dim));
const ComputeDataType half_rotated_input = [&] {
const index_t i_b = i[0];
const index_t i_h = i[1];
const index_t hdim = input_bhsd.get_length(3);
const index_t half_hdim = hdim / 2;
if(interleaved)
{
const index_t pos = (i_d < half_hdim ? (i_d * 2 + 1) : (i_d - half_hdim) * 2);
const ComputeDataType sign = (i_d < half_hdim ? 1 : -1);
return sign * type_convert<ComputeDataType>(input_bhsd(i_b, i_h, i_s, pos));
}
else
{
const index_t pos = (i_d + half_hdim) % hdim;
const ComputeDataType sign = (pos < half_hdim ? 1 : -1);
return sign * type_convert<ComputeDataType>(input_bhsd(i_b, i_h, i_s, pos));
}
}();
ComputeDataType result =
type_convert<ComputeDataType>(input_bhsd(i)) * cos + half_rotated_input * sin;
self(i) = type_convert<DataType>(result);
});
}
} // namespace ck_tile