mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Reduce input/output dimensions
This commit is contained in:
@@ -16,29 +16,29 @@ namespace detail {
|
||||
}
|
||||
|
||||
template <typename ComputeDataType, typename DataType>
|
||||
CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>& input_bhsd,
|
||||
CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>& input_bsd,
|
||||
const HostTensor<DataType>& cos_sd,
|
||||
const HostTensor<DataType>& sin_sd,
|
||||
bool interleaved,
|
||||
HostTensor<DataType>& output_bhsd)
|
||||
HostTensor<DataType>& output_bsd)
|
||||
{
|
||||
assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
|
||||
assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
|
||||
cos_sd.get_length(1) == sin_sd.get_length(1));
|
||||
|
||||
const index_t rotary_dim = cos_sd.get_length(1) * 2;
|
||||
assert(static_cast<std::size_t>(rotary_dim) <= input_bhsd.get_length(3));
|
||||
assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));
|
||||
|
||||
output_bhsd.ForEach([&](auto& self, auto i) {
|
||||
const index_t i_d = i[3];
|
||||
output_bsd.ForEach([&](auto& self, auto i) {
|
||||
const index_t i_d = i[2];
|
||||
if(rotary_dim <= i_d)
|
||||
{
|
||||
self(i) = input_bhsd(i);
|
||||
self(i) = input_bsd(i);
|
||||
return;
|
||||
}
|
||||
assert(i_d < rotary_dim);
|
||||
|
||||
const index_t i_s = i[2];
|
||||
const index_t i_s = i[1];
|
||||
|
||||
const ComputeDataType cos = type_convert<ComputeDataType>(
|
||||
interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % cos_sd.get_length(1)));
|
||||
@@ -46,9 +46,8 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>
|
||||
interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % sin_sd.get_length(1)));
|
||||
const ComputeDataType half_rotated_input = [&] {
|
||||
const index_t i_b = i[0];
|
||||
const index_t i_h = i[1];
|
||||
|
||||
const index_t hdim = input_bhsd.get_length(3);
|
||||
const index_t hdim = input_bsd.get_length(2);
|
||||
const index_t half_hdim = hdim / 2;
|
||||
|
||||
if(interleaved)
|
||||
@@ -56,17 +55,17 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>
|
||||
const bool is_even = (i_d % 2 == 0);
|
||||
const index_t pos = i_d + (is_even ? 1 : -1);
|
||||
const ComputeDataType sign = (is_even ? -1 : 1);
|
||||
return sign * type_convert<ComputeDataType>(input_bhsd(i_b, i_h, i_s, pos));
|
||||
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t pos = (i_d + half_hdim) % hdim;
|
||||
const ComputeDataType sign = (pos < half_hdim ? 1 : -1);
|
||||
return sign * type_convert<ComputeDataType>(input_bhsd(i_b, i_h, i_s, pos));
|
||||
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
|
||||
}
|
||||
}();
|
||||
ComputeDataType result =
|
||||
type_convert<ComputeDataType>(input_bhsd(i)) * cos + half_rotated_input * sin;
|
||||
type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
|
||||
|
||||
self(i) = type_convert<DataType>(result);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user