diff --git a/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp b/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp index 07e5f42ec5..91d4467fee 100644 --- a/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp +++ b/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp @@ -16,29 +16,29 @@ namespace detail { } template -CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor& input_bhsd, +CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor& input_bsd, const HostTensor& cos_sd, const HostTensor& sin_sd, bool interleaved, - HostTensor& output_bhsd) + HostTensor& output_bsd) { assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2); assert(cos_sd.get_length(0) == sin_sd.get_length(0) && cos_sd.get_length(1) == sin_sd.get_length(1)); const index_t rotary_dim = cos_sd.get_length(1) * 2; - assert(static_cast(rotary_dim) <= input_bhsd.get_length(3)); + assert(static_cast(rotary_dim) <= input_bsd.get_length(2)); - output_bhsd.ForEach([&](auto& self, auto i) { - const index_t i_d = i[3]; + output_bsd.ForEach([&](auto& self, auto i) { + const index_t i_d = i[2]; if(rotary_dim <= i_d) { - self(i) = input_bhsd(i); + self(i) = input_bsd(i); return; } assert(i_d < rotary_dim); - const index_t i_s = i[2]; + const index_t i_s = i[1]; const ComputeDataType cos = type_convert( interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % cos_sd.get_length(1))); @@ -46,9 +46,8 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % sin_sd.get_length(1))); const ComputeDataType half_rotated_input = [&] { const index_t i_b = i[0]; - const index_t i_h = i[1]; - const index_t hdim = input_bhsd.get_length(3); + const index_t hdim = input_bsd.get_length(2); const index_t half_hdim = hdim / 2; if(interleaved) @@ -56,17 +55,17 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor const bool is_even = (i_d % 2 == 0); const index_t pos = i_d + (is_even ? 1 : -1); const ComputeDataType sign = (is_even ? -1 : 1); - return sign * type_convert(input_bhsd(i_b, i_h, i_s, pos)); + return sign * type_convert(input_bsd(i_b, i_s, pos)); } else { const index_t pos = (i_d + half_hdim) % hdim; const ComputeDataType sign = (pos < half_hdim ? 1 : -1); - return sign * type_convert(input_bhsd(i_b, i_h, i_s, pos)); + return sign * type_convert(input_bsd(i_b, i_s, pos)); } }(); ComputeDataType result = - type_convert(input_bhsd(i)) * cos + half_rotated_input * sin; + type_convert(input_bsd(i)) * cos + half_rotated_input * sin; self(i) = type_convert(result); });