Reduce input/output tensor dimensions from 4-D (bhsd) to 3-D (bsd)

This commit is contained in:
PoYen, Chen
2024-07-12 06:49:43 +00:00
parent 3183b68921
commit ff75eff3bf

View File

@@ -16,29 +16,29 @@ namespace detail {
}
template <typename ComputeDataType, typename DataType>
CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>& input_bhsd,
CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>& input_bsd,
const HostTensor<DataType>& cos_sd,
const HostTensor<DataType>& sin_sd,
bool interleaved,
HostTensor<DataType>& output_bhsd)
HostTensor<DataType>& output_bsd)
{
assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
cos_sd.get_length(1) == sin_sd.get_length(1));
const index_t rotary_dim = cos_sd.get_length(1) * 2;
assert(static_cast<std::size_t>(rotary_dim) <= input_bhsd.get_length(3));
assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));
output_bhsd.ForEach([&](auto& self, auto i) {
const index_t i_d = i[3];
output_bsd.ForEach([&](auto& self, auto i) {
const index_t i_d = i[2];
if(rotary_dim <= i_d)
{
self(i) = input_bhsd(i);
self(i) = input_bsd(i);
return;
}
assert(i_d < rotary_dim);
const index_t i_s = i[2];
const index_t i_s = i[1];
const ComputeDataType cos = type_convert<ComputeDataType>(
interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % cos_sd.get_length(1)));
@@ -46,9 +46,8 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>
interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % sin_sd.get_length(1)));
const ComputeDataType half_rotated_input = [&] {
const index_t i_b = i[0];
const index_t i_h = i[1];
const index_t hdim = input_bhsd.get_length(3);
const index_t hdim = input_bsd.get_length(2);
const index_t half_hdim = hdim / 2;
if(interleaved)
@@ -56,17 +55,17 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>
const bool is_even = (i_d % 2 == 0);
const index_t pos = i_d + (is_even ? 1 : -1);
const ComputeDataType sign = (is_even ? -1 : 1);
return sign * type_convert<ComputeDataType>(input_bhsd(i_b, i_h, i_s, pos));
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
else
{
const index_t pos = (i_d + half_hdim) % hdim;
const ComputeDataType sign = (pos < half_hdim ? 1 : -1);
return sign * type_convert<ComputeDataType>(input_bhsd(i_b, i_h, i_s, pos));
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
}();
ComputeDataType result =
type_convert<ComputeDataType>(input_bhsd(i)) * cos + half_rotated_input * sin;
type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
self(i) = type_convert<DataType>(result);
});