mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Finish reference_rotary_position_embedding() impl
This commit is contained in:
@@ -28,12 +28,46 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>
|
||||
|
||||
const index_t rotary_dim = cos_sd.get_length(1) * 2;
|
||||
assert(rotary_dim <= input_bhsd.get_length(3));
|
||||
(void)rotary_dim;
|
||||
(void)input_bhsd;
|
||||
(void)sin_sd;
|
||||
(void)cos_sd;
|
||||
(void)interleaved;
|
||||
(void)output_bhsd;
|
||||
|
||||
output_bhsd.ForEach([&](auto& self, auto i) {
|
||||
const index_t i_d = i[3];
|
||||
if(rotary_dim <= i_d)
|
||||
{
|
||||
self(i) = input_bhsd(i);
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t i_s = i[2];
|
||||
|
||||
const ComputeDataType cos =
|
||||
(interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % rotary_dim));
|
||||
const ComputeDataType sin =
|
||||
(interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % rotary_dim));
|
||||
const ComputeDataType half_rotated_input = [&] {
|
||||
const index_t i_b = i[0];
|
||||
const index_t i_h = i[1];
|
||||
|
||||
const index_t hdim = input_bhsd.get_length(3);
|
||||
const index_t half_hdim = hdim / 2;
|
||||
|
||||
if(interleaved)
|
||||
{
|
||||
const index_t pos = (i_d < half_hdim ? (i_d * 2 + 1) : (i_d - half_hdim) * 2);
|
||||
const ComputeDataType sign = (i_d < half_hdim ? 1 : -1);
|
||||
return sign * type_convert<ComputeDataType>(input_bhsd(i_b, i_h, i_s, pos));
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t pos = (i_d + half_hdim) % hdim;
|
||||
const ComputeDataType sign = (pos < half_hdim ? 1 : -1);
|
||||
return sign * type_convert<ComputeDataType>(input_bhsd(i_b, i_h, i_s, pos));
|
||||
}
|
||||
}();
|
||||
ComputeDataType result =
|
||||
type_convert<ComputeDataType>(input_bhsd(i)) * cos + half_rotated_input * sin;
|
||||
|
||||
self(i) = type_convert<DataType>(result);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user