diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index cab807a4d2..deebe90bf7 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -15,12 +15,12 @@ #include "ck_tile/host/reference/reference_batched_elementwise.hpp" #include "ck_tile/host/reference/reference_batched_gemm.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp" +#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" -#include "ck_tile/host/reference/reference_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/timer.hpp" diff --git a/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp b/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp similarity index 82% rename from include/ck_tile/host/reference/reference_rotary_position_embedding.hpp rename to include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp index 91d4467fee..9f031a174d 100644 --- a/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp +++ b/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp @@ -16,11 +16,11 @@ namespace detail { } template -CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor& input_bsd, - const HostTensor& cos_sd, - const HostTensor& sin_sd, - bool interleaved, - HostTensor& output_bsd) +CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor& input_bsd, + const HostTensor& cos_sd, + const HostTensor& sin_sd, + bool interleaved, + HostTensor& output_bsd) { assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2); assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&